From 900ac49e606260fb4bb5029ab62f75413f3e716c Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Mon, 23 Dec 2024 15:33:20 +0000 Subject: [PATCH] Refactor tests (#1011) * Add parallel test deps * Update signature * Add encoders tests * Update gitignore * Update encoders for timm-universal * Add parallel tests run * Disable models tests * Add uv to CI * Add uv to minimum * Add show-install-packages * Increase to 3 workers * Fix show-packages * Change back for 2 workers * Add coverage * Basic model test * Fix * Move model archs * Add base params test * Fix timm test for minimum version * Remove deprecated utils from coverage * Fix * Fix * Exclude conversion script * Add save-load test, add aux head test * Remove custom encoder * Set encoder for models tests * Docs + flag for anyres * Fix loading from config * Bump min hf-hub to 0.25.0 * Fix minimal * Add test with hub checkpoint * Fixing minimum * Fix * Fix torch for minimum tests * Update torch version and run-slow * run-slow * Show skipped * [run-slow] Fixing minimum * [run-slow] Fixing minimum * Fix decorator * Raise error * [run-slow] Fixing run slow * [run-slow] Fixing run slow * Run slow tests in separate job * Fix * Fixes * Add device * Bump tolerance * Add device * Fixup --- .github/workflows/tests.yml | 49 ++++- .gitignore | 1 + Makefile | 5 +- misc/generate_test_models.py | 41 ++++ pyproject.toml | 24 ++ requirements/minimum.old | 1 + requirements/test.txt | 4 +- segmentation_models_pytorch/__init__.py | 33 +-- segmentation_models_pytorch/base/hub_mixin.py | 11 + segmentation_models_pytorch/base/model.py | 11 +- .../decoders/deeplabv3/model.py | 3 + .../decoders/fpn/model.py | 2 + .../decoders/linknet/model.py | 2 + .../decoders/manet/model.py | 2 + .../decoders/pan/model.py | 2 + .../decoders/pspnet/model.py | 2 + .../decoders/segformer/model.py | 2 + .../decoders/unet/model.py | 8 +- .../decoders/unetplusplus/model.py | 2 + .../decoders/upernet/model.py | 2 + .../encoders/inceptionv4.py | 2 +- 
tests/encoders/__init__.py | 0 tests/encoders/base.py | 208 ++++++++++++++++++ .../test_pretrainedmodels_encoders.py | 71 ++++++ tests/encoders/test_smp_encoders.py | 41 ++++ tests/encoders/test_timm_ported_encoders.py | 131 +++++++++++ tests/encoders/test_timm_universal.py | 16 ++ tests/encoders/test_torchvision_encoders.py | 24 ++ tests/models/__init__.py | 0 tests/models/base.py | 206 +++++++++++++++++ tests/models/test_deeplab.py | 16 ++ tests/models/test_fpn.py | 7 + tests/models/test_linknet.py | 7 + tests/models/test_manet.py | 7 + tests/models/test_pan.py | 11 + tests/models/test_psp.py | 9 + tests/models/test_segformer.py | 43 ++++ tests/models/test_unet.py | 7 + tests/models/test_unetplusplus.py | 7 + tests/models/test_upernet.py | 8 + tests/test_models.py | 194 ---------------- tests/utils.py | 47 ++++ 42 files changed, 1046 insertions(+), 223 deletions(-) create mode 100644 misc/generate_test_models.py create mode 100644 tests/encoders/__init__.py create mode 100644 tests/encoders/base.py create mode 100644 tests/encoders/test_pretrainedmodels_encoders.py create mode 100644 tests/encoders/test_smp_encoders.py create mode 100644 tests/encoders/test_timm_ported_encoders.py create mode 100644 tests/encoders/test_timm_universal.py create mode 100644 tests/encoders/test_torchvision_encoders.py create mode 100644 tests/models/__init__.py create mode 100644 tests/models/base.py create mode 100644 tests/models/test_deeplab.py create mode 100644 tests/models/test_fpn.py create mode 100644 tests/models/test_linknet.py create mode 100644 tests/models/test_manet.py create mode 100644 tests/models/test_pan.py create mode 100644 tests/models/test_psp.py create mode 100644 tests/models/test_segformer.py create mode 100644 tests/models/test_unet.py create mode 100644 tests/models/test_unetplusplus.py create mode 100644 tests/models/test_upernet.py delete mode 100644 tests/test_models.py create mode 100644 tests/utils.py diff --git a/.github/workflows/tests.yml 
b/.github/workflows/tests.yml index e2767433..e9b34d73 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -36,24 +36,61 @@ jobs: runs-on: ${{ matrix.os }} steps: - uses: actions/checkout@v4 + - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} + - name: Install dependencies - run: python -m pip install -r requirements/required.txt -r requirements/test.txt - - name: Test with pytest - run: pytest + run: | + python -m pip install uv + python -m uv pip install --system -r requirements/required.txt -r requirements/test.txt + + - name: Show installed packages + run: | + python -m pip list + + - name: Test with PyTest + run: | + pytest -v -rsx -n 2 --cov=segmentation_models_pytorch --cov-report=xml --cov-config=pyproject.toml -k "not logits_match" + + - name: Upload coverage reports to Codecov + uses: codecov/codecov-action@v5 + with: + token: ${{ secrets.CODECOV_TOKEN }} + slug: qubvel-org/segmentation_models.pytorch + if: matrix.os == 'macos-latest' && matrix.python-version == '3.12' + + test_logits_match: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.10" + - name: Install dependencies + run: | + python -m pip install uv + python -m uv pip install --system -r requirements/required.txt -r requirements/test.txt + - name: Test with PyTest + run: RUN_SLOW=1 pytest -v -rsx -n 2 -k "logits_match" minimum: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - - name: Set up Python ${{ matrix.python-version }} + - name: Set up Python uses: actions/setup-python@v5 with: python-version: "3.9" - name: Install dependencies - run: python -m pip install -r requirements/minimum.old -r requirements/test.txt + run: | + python -m pip install uv + python -m uv pip install --system -r requirements/minimum.old -r requirements/test.txt + - name: Show installed packages + run: | 
+ python -m pip list - name: Test with pytest - run: pytest + run: pytest -v -rsx -n 2 -k "not logits_match" diff --git a/.gitignore b/.gitignore index 99a7807d..33db579f 100644 --- a/.gitignore +++ b/.gitignore @@ -93,6 +93,7 @@ venv/ ENV/ env.bak/ venv.bak/ +.vscode/ # Spyder project settings .spyderproject diff --git a/Makefile b/Makefile index 9e974026..a58d230f 100644 --- a/Makefile +++ b/Makefile @@ -7,7 +7,10 @@ install_dev: .venv .venv/bin/pip install -e ".[test]" test: .venv - .venv/bin/pytest -p no:cacheprovider tests/ + .venv/bin/pytest -v -rsx -n 2 tests/ -k "not logits_match" + +test_all: .venv + RUN_SLOW=1 .venv/bin/pytest -v -rsx -n 2 tests/ table: .venv/bin/python misc/generate_table.py diff --git a/misc/generate_test_models.py b/misc/generate_test_models.py new file mode 100644 index 00000000..61d6bfd0 --- /dev/null +++ b/misc/generate_test_models.py @@ -0,0 +1,41 @@ +import os +import torch +import tempfile +import huggingface_hub +import segmentation_models_pytorch as smp + +HUB_REPO = "smp-test-models" +ENCODER_NAME = "tu-resnet18" + +api = huggingface_hub.HfApi(token=os.getenv("HF_TOKEN")) + +for model_name, model_class in smp.MODEL_ARCHITECTURES_MAPPING.items(): + model = model_class(encoder_name=ENCODER_NAME) + model = model.eval() + + # generate test sample + torch.manual_seed(423553) + sample = torch.rand(1, 3, 256, 256) + + with torch.no_grad(): + output = model(sample) + + with tempfile.TemporaryDirectory() as tmpdir: + # save model + model.save_pretrained(f"{tmpdir}") + + # save input and output + torch.save(sample, f"{tmpdir}/input-tensor.pth") + torch.save(output, f"{tmpdir}/output-tensor.pth") + + # create repo + repo_id = f"{HUB_REPO}/{model_name}-{ENCODER_NAME}" + if not api.repo_exists(repo_id=repo_id): + api.create_repo(repo_id=repo_id, repo_type="model") + + # upload to hub + api.upload_folder( + folder_path=tmpdir, + repo_id=f"{HUB_REPO}/{model_name}-{ENCODER_NAME}", + repo_type="model", + ) diff --git a/pyproject.toml 
b/pyproject.toml index 3df76c0f..5df18bc7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,6 +40,8 @@ docs = [ ] test = [ 'pytest', + 'pytest-cov', + 'pytest-xdist', 'ruff', ] @@ -55,3 +57,25 @@ version = {attr = 'segmentation_models_pytorch.__version__.__version__'} [tool.setuptools.packages.find] include = ['segmentation_models_pytorch*'] + +[tool.pytest.ini_options] +markers = [ + "deeplabv3", + "deeplabv3plus", + "fpn", + "linknet", + "manet", + "pan", + "psp", + "segformer", + "unet", + "unetplusplus", + "upernet", + "logits_match", +] + +[tool.coverage.run] +omit = [ + "segmentation_models_pytorch/utils/*", + "**/convert_*", +] diff --git a/requirements/minimum.old b/requirements/minimum.old index 3f687871..1080bdb4 100644 --- a/requirements/minimum.old +++ b/requirements/minimum.old @@ -8,3 +8,4 @@ timm==0.9.0 torch==1.9.0 torchvision==0.10.0 tqdm==4.42.1 +Jinja2==3.0.0 diff --git a/requirements/test.txt b/requirements/test.txt index c635aa51..ca126ece 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -1,2 +1,4 @@ pytest==8.3.4 -ruff==0.8.4 +pytest-xdist==3.6.1 +pytest-cov==6.0.0 +ruff==0.8.4 \ No newline at end of file diff --git a/segmentation_models_pytorch/__init__.py b/segmentation_models_pytorch/__init__.py index 5cde6004..f1807836 100644 --- a/segmentation_models_pytorch/__init__.py +++ b/segmentation_models_pytorch/__init__.py @@ -30,6 +30,21 @@ "ignore", message=r'"is" with \'str\' literal.*', category=SyntaxWarning ) # for python >= 3.12 +_MODEL_ARCHITECTURES = [ + Unet, + UnetPlusPlus, + MAnet, + Linknet, + FPN, + PSPNet, + DeepLabV3, + DeepLabV3Plus, + PAN, + UPerNet, + Segformer, +] +MODEL_ARCHITECTURES_MAPPING = {a.__name__.lower(): a for a in _MODEL_ARCHITECTURES} + def create_model( arch: str, @@ -43,26 +58,12 @@ def create_model( parameters, without using its class """ - archs = [ - Unet, - UnetPlusPlus, - MAnet, - Linknet, - FPN, - PSPNet, - DeepLabV3, - DeepLabV3Plus, - PAN, - UPerNet, - Segformer, - ] - archs_dict = 
{a.__name__.lower(): a for a in archs} try: - model_class = archs_dict[arch.lower()] + model_class = MODEL_ARCHITECTURES_MAPPING[arch.lower()] except KeyError: raise KeyError( "Wrong architecture type `{}`. Available options are: {}".format( - arch, list(archs_dict.keys()) + arch, list(MODEL_ARCHITECTURES_MAPPING.keys()) ) ) return model_class( diff --git a/segmentation_models_pytorch/base/hub_mixin.py b/segmentation_models_pytorch/base/hub_mixin.py index 8095c5b8..360aa521 100644 --- a/segmentation_models_pytorch/base/hub_mixin.py +++ b/segmentation_models_pytorch/base/hub_mixin.py @@ -136,3 +136,14 @@ def from_pretrained(pretrained_model_name_or_path: str, *args, **kwargs): model_class = getattr(smp, model_class_name) return model_class.from_pretrained(pretrained_model_name_or_path, *args, **kwargs) + + +def supports_config_loading(func): + """Decorator to filter special config kwargs""" + + @wraps(func) + def wrapper(self, *args, **kwargs): + kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")} + return func(self, *args, **kwargs) + + return wrapper diff --git a/segmentation_models_pytorch/base/model.py b/segmentation_models_pytorch/base/model.py index 29c7dd2a..6d7bf643 100644 --- a/segmentation_models_pytorch/base/model.py +++ b/segmentation_models_pytorch/base/model.py @@ -5,6 +5,12 @@ class SegmentationModel(torch.nn.Module, SMPHubMixin): + """Base class for all segmentation models.""" + + # if model supports shape not divisible by 2 ^ n + # set to False + requires_divisible_input_shape = True + def initialize(self): init.initialize_decoder(self.decoder) init.initialize_head(self.segmentation_head) @@ -12,6 +18,9 @@ def initialize(self): init.initialize_head(self.classification_head) def check_input_shape(self, x): + """Check if the input shape is divisible by the output stride. + If not, raise a RuntimeError. 
+ """ h, w = x.shape[-2:] output_stride = self.encoder.output_stride if h % output_stride != 0 or w % output_stride != 0: @@ -33,7 +42,7 @@ def check_input_shape(self, x): def forward(self, x): """Sequentially pass `x` trough model`s encoder, decoder and heads""" - if not torch.jit.is_tracing(): + if not torch.jit.is_tracing() and self.requires_divisible_input_shape: self.check_input_shape(x) features = self.encoder(x) diff --git a/segmentation_models_pytorch/decoders/deeplabv3/model.py b/segmentation_models_pytorch/decoders/deeplabv3/model.py index 830906cb..654e38d4 100644 --- a/segmentation_models_pytorch/decoders/deeplabv3/model.py +++ b/segmentation_models_pytorch/decoders/deeplabv3/model.py @@ -8,6 +8,7 @@ SegmentationModel, ) from segmentation_models_pytorch.encoders import get_encoder +from segmentation_models_pytorch.base.hub_mixin import supports_config_loading from .decoder import DeepLabV3Decoder, DeepLabV3PlusDecoder @@ -54,6 +55,7 @@ class DeepLabV3(SegmentationModel): """ + @supports_config_loading def __init__( self, encoder_name: str = "resnet34", @@ -163,6 +165,7 @@ class DeepLabV3Plus(SegmentationModel): """ + @supports_config_loading def __init__( self, encoder_name: str = "resnet34", diff --git a/segmentation_models_pytorch/decoders/fpn/model.py b/segmentation_models_pytorch/decoders/fpn/model.py index 373269c5..7420b289 100644 --- a/segmentation_models_pytorch/decoders/fpn/model.py +++ b/segmentation_models_pytorch/decoders/fpn/model.py @@ -6,6 +6,7 @@ SegmentationModel, ) from segmentation_models_pytorch.encoders import get_encoder +from segmentation_models_pytorch.base.hub_mixin import supports_config_loading from .decoder import FPNDecoder @@ -51,6 +52,7 @@ class FPN(SegmentationModel): """ + @supports_config_loading def __init__( self, encoder_name: str = "resnet34", diff --git a/segmentation_models_pytorch/decoders/linknet/model.py b/segmentation_models_pytorch/decoders/linknet/model.py index 708ea562..356468ed 100644 --- 
a/segmentation_models_pytorch/decoders/linknet/model.py +++ b/segmentation_models_pytorch/decoders/linknet/model.py @@ -6,6 +6,7 @@ SegmentationModel, ) from segmentation_models_pytorch.encoders import get_encoder +from segmentation_models_pytorch.base.hub_mixin import supports_config_loading from .decoder import LinknetDecoder @@ -53,6 +54,7 @@ class Linknet(SegmentationModel): https://arxiv.org/abs/1707.03718 """ + @supports_config_loading def __init__( self, encoder_name: str = "resnet34", diff --git a/segmentation_models_pytorch/decoders/manet/model.py b/segmentation_models_pytorch/decoders/manet/model.py index 6651dee6..6ed59207 100644 --- a/segmentation_models_pytorch/decoders/manet/model.py +++ b/segmentation_models_pytorch/decoders/manet/model.py @@ -6,6 +6,7 @@ SegmentationModel, ) from segmentation_models_pytorch.encoders import get_encoder +from segmentation_models_pytorch.base.hub_mixin import supports_config_loading from .decoder import MAnetDecoder @@ -56,6 +57,7 @@ class MAnet(SegmentationModel): """ + @supports_config_loading def __init__( self, encoder_name: str = "resnet34", diff --git a/segmentation_models_pytorch/decoders/pan/model.py b/segmentation_models_pytorch/decoders/pan/model.py index 712541a5..6d5e78c2 100644 --- a/segmentation_models_pytorch/decoders/pan/model.py +++ b/segmentation_models_pytorch/decoders/pan/model.py @@ -6,6 +6,7 @@ SegmentationModel, ) from segmentation_models_pytorch.encoders import get_encoder +from segmentation_models_pytorch.base.hub_mixin import supports_config_loading from .decoder import PANDecoder @@ -53,6 +54,7 @@ class PAN(SegmentationModel): """ + @supports_config_loading def __init__( self, encoder_name: str = "resnet34", diff --git a/segmentation_models_pytorch/decoders/pspnet/model.py b/segmentation_models_pytorch/decoders/pspnet/model.py index dbf04ea4..8b99b3da 100644 --- a/segmentation_models_pytorch/decoders/pspnet/model.py +++ b/segmentation_models_pytorch/decoders/pspnet/model.py @@ -6,6 +6,7 @@ 
SegmentationModel, ) from segmentation_models_pytorch.encoders import get_encoder +from segmentation_models_pytorch.base.hub_mixin import supports_config_loading from .decoder import PSPDecoder @@ -54,6 +55,7 @@ class PSPNet(SegmentationModel): https://arxiv.org/abs/1612.01105 """ + @supports_config_loading def __init__( self, encoder_name: str = "resnet34", diff --git a/segmentation_models_pytorch/decoders/segformer/model.py b/segmentation_models_pytorch/decoders/segformer/model.py index 2a5e8dba..45805de7 100644 --- a/segmentation_models_pytorch/decoders/segformer/model.py +++ b/segmentation_models_pytorch/decoders/segformer/model.py @@ -6,6 +6,7 @@ SegmentationModel, ) from segmentation_models_pytorch.encoders import get_encoder +from segmentation_models_pytorch.base.hub_mixin import supports_config_loading from .decoder import SegformerDecoder @@ -46,6 +47,7 @@ class Segformer(SegmentationModel): """ + @supports_config_loading def __init__( self, encoder_name: str = "resnet34", diff --git a/segmentation_models_pytorch/decoders/unet/model.py b/segmentation_models_pytorch/decoders/unet/model.py index 0ac7b5bd..547581eb 100644 --- a/segmentation_models_pytorch/decoders/unet/model.py +++ b/segmentation_models_pytorch/decoders/unet/model.py @@ -1,4 +1,4 @@ -from typing import Any, List, Optional, Union +from typing import Any, Optional, Union, Tuple, Callable from segmentation_models_pytorch.base import ( ClassificationHead, @@ -6,6 +6,7 @@ SegmentationModel, ) from segmentation_models_pytorch.encoders import get_encoder +from segmentation_models_pytorch.base.hub_mixin import supports_config_loading from .decoder import UnetDecoder @@ -55,17 +56,18 @@ class Unet(SegmentationModel): """ + @supports_config_loading def __init__( self, encoder_name: str = "resnet34", encoder_depth: int = 5, encoder_weights: Optional[str] = "imagenet", decoder_use_batchnorm: bool = True, - decoder_channels: List[int] = (256, 128, 64, 32, 16), + decoder_channels: Tuple[int, ...] 
= (256, 128, 64, 32, 16), decoder_attention_type: Optional[str] = None, in_channels: int = 3, classes: int = 1, - activation: Optional[Union[str, callable]] = None, + activation: Optional[Union[str, Callable]] = None, aux_params: Optional[dict] = None, **kwargs: dict[str, Any], ): diff --git a/segmentation_models_pytorch/decoders/unetplusplus/model.py b/segmentation_models_pytorch/decoders/unetplusplus/model.py index 9ba72321..9d4a1e35 100644 --- a/segmentation_models_pytorch/decoders/unetplusplus/model.py +++ b/segmentation_models_pytorch/decoders/unetplusplus/model.py @@ -6,6 +6,7 @@ SegmentationModel, ) from segmentation_models_pytorch.encoders import get_encoder +from segmentation_models_pytorch.base.hub_mixin import supports_config_loading from .decoder import UnetPlusPlusDecoder @@ -55,6 +56,7 @@ class UnetPlusPlus(SegmentationModel): """ + @supports_config_loading def __init__( self, encoder_name: str = "resnet34", diff --git a/segmentation_models_pytorch/decoders/upernet/model.py b/segmentation_models_pytorch/decoders/upernet/model.py index de30a7bb..076ed2de 100644 --- a/segmentation_models_pytorch/decoders/upernet/model.py +++ b/segmentation_models_pytorch/decoders/upernet/model.py @@ -6,6 +6,7 @@ SegmentationModel, ) from segmentation_models_pytorch.encoders import get_encoder +from segmentation_models_pytorch.base.hub_mixin import supports_config_loading from .decoder import UPerNetDecoder @@ -47,6 +48,7 @@ class UPerNet(SegmentationModel): """ + @supports_config_loading def __init__( self, encoder_name: str = "resnet34", diff --git a/segmentation_models_pytorch/encoders/inceptionv4.py b/segmentation_models_pytorch/encoders/inceptionv4.py index 83adf003..96540f9a 100644 --- a/segmentation_models_pytorch/encoders/inceptionv4.py +++ b/segmentation_models_pytorch/encoders/inceptionv4.py @@ -49,7 +49,7 @@ def __init__(self, stage_idxs, out_channels, depth=5, **kwargs): # remove linear layers del self.last_linear - def make_dilated(self, stage_list, 
dilation_list): + def make_dilated(self, *args, **kwargs): raise ValueError( "InceptionV4 encoder does not support dilated mode " "due to pooling operation for downsampling!" diff --git a/tests/encoders/__init__.py b/tests/encoders/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/encoders/base.py b/tests/encoders/base.py new file mode 100644 index 00000000..39cd4164 --- /dev/null +++ b/tests/encoders/base.py @@ -0,0 +1,208 @@ +import unittest +import torch +import segmentation_models_pytorch as smp + +from functools import lru_cache +from tests.utils import default_device + + +class BaseEncoderTester(unittest.TestCase): + encoder_names = [] + + # standard encoder configuration + num_output_features = 6 + output_strides = [1, 2, 4, 8, 16, 32] + supports_dilated = True + + # test sample configuration + default_batch_size = 1 + default_num_channels = 3 + default_height = 64 + default_width = 64 + + # test configurations + in_channels_to_test = [1, 3, 4] + depth_to_test = [3, 4, 5] + strides_to_test = [8, 16] # 32 is a default one + + @lru_cache + def _get_sample(self, batch_size=1, num_channels=3, height=32, width=32): + return torch.rand(batch_size, num_channels, height, width) + + def get_features_output_strides(self, sample, features): + height, width = sample.shape[2:] + height_strides = [height // f.shape[2] for f in features] + width_strides = [width // f.shape[3] for f in features] + return height_strides, width_strides + + def test_forward_backward(self): + sample = self._get_sample( + batch_size=self.default_batch_size, + num_channels=self.default_num_channels, + height=self.default_height, + width=self.default_width, + ).to(default_device) + for encoder_name in self.encoder_names: + with self.subTest(encoder_name=encoder_name): + # init encoder + encoder = smp.encoders.get_encoder( + encoder_name, in_channels=3, encoder_weights=None + ).to(default_device) + + # forward + features = encoder.forward(sample) + self.assertEqual( + 
len(features), + self.num_output_features, + f"Encoder `{encoder_name}` should have {self.num_output_features} output feature maps, but has {len(features)}", + ) + + # backward + features[-1].mean().backward() + + def test_in_channels(self): + cases = [ + (encoder_name, in_channels) + for encoder_name in self.encoder_names + for in_channels in self.in_channels_to_test + ] + + for encoder_name, in_channels in cases: + sample = self._get_sample( + batch_size=self.default_batch_size, + num_channels=in_channels, + height=self.default_height, + width=self.default_width, + ).to(default_device) + + with self.subTest(encoder_name=encoder_name, in_channels=in_channels): + encoder = smp.encoders.get_encoder( + encoder_name, in_channels=in_channels, encoder_weights=None + ).to(default_device) + encoder.eval() + + # forward + with torch.no_grad(): + encoder.forward(sample) + + def test_depth(self): + sample = self._get_sample( + batch_size=self.default_batch_size, + num_channels=self.default_num_channels, + height=self.default_height, + width=self.default_width, + ).to(default_device) + + cases = [ + (encoder_name, depth) + for encoder_name in self.encoder_names + for depth in self.depth_to_test + ] + + for encoder_name, depth in cases: + with self.subTest(encoder_name=encoder_name, depth=depth): + encoder = smp.encoders.get_encoder( + encoder_name, + in_channels=self.default_num_channels, + encoder_weights=None, + depth=depth, + ).to(default_device) + encoder.eval() + + # forward + with torch.no_grad(): + features = encoder.forward(sample) + + # check number of features + self.assertEqual( + len(features), + depth + 1, + f"Encoder `{encoder_name}` should have {depth + 1} output feature maps, but has {len(features)}", + ) + + # check feature strides + height_strides, width_strides = self.get_features_output_strides( + sample, features + ) + self.assertEqual( + height_strides, + self.output_strides[: depth + 1], + f"Encoder `{encoder_name}` should have output strides 
{self.output_strides[:depth + 1]}, but has {height_strides}", + ) + self.assertEqual( + width_strides, + self.output_strides[: depth + 1], + f"Encoder `{encoder_name}` should have output strides {self.output_strides[:depth + 1]}, but has {width_strides}", + ) + + # check encoder output stride property + self.assertEqual( + encoder.output_stride, + self.output_strides[depth], + f"Encoder `{encoder_name}` last feature map should have output stride {self.output_strides[depth]}, but has {encoder.output_stride}", + ) + + # check out channels also have proper length + self.assertEqual( + len(encoder.out_channels), + depth + 1, + f"Encoder `{encoder_name}` should have {depth + 1} out_channels, but has {len(encoder.out_channels)}", + ) + + def test_dilated(self): + sample = self._get_sample( + batch_size=self.default_batch_size, + num_channels=self.default_num_channels, + height=self.default_height, + width=self.default_width, + ).to(default_device) + + cases = [ + (encoder_name, stride) + for encoder_name in self.encoder_names + for stride in self.strides_to_test + ] + + # special case for encoders that do not support dilated model + # just check proper error is raised + if not self.supports_dilated: + with self.assertRaises(ValueError, msg="not support dilated mode"): + encoder_name, stride = cases[0] + encoder = smp.encoders.get_encoder( + encoder_name, + in_channels=self.default_num_channels, + encoder_weights=None, + output_stride=stride, + ).to(default_device) + return + + for encoder_name, stride in cases: + with self.subTest(encoder_name=encoder_name, stride=stride): + encoder = smp.encoders.get_encoder( + encoder_name, + in_channels=self.default_num_channels, + encoder_weights=None, + output_stride=stride, + ).to(default_device) + encoder.eval() + + # forward + with torch.no_grad(): + features = encoder.forward(sample) + + height_strides, width_strides = self.get_features_output_strides( + sample, features + ) + expected_height_strides = [min(stride, s) for s in 
height_strides] + expected_width_strides = [min(stride, s) for s in width_strides] + + self.assertEqual( + height_strides, + expected_height_strides, + f"Encoder `{encoder_name}` should have height output strides {expected_height_strides}, but has {height_strides}", + ) + self.assertEqual( + width_strides, + expected_width_strides, + f"Encoder `{encoder_name}` should have width output strides {expected_width_strides}, but has {width_strides}", + ) diff --git a/tests/encoders/test_pretrainedmodels_encoders.py b/tests/encoders/test_pretrainedmodels_encoders.py new file mode 100644 index 00000000..bbde576c --- /dev/null +++ b/tests/encoders/test_pretrainedmodels_encoders.py @@ -0,0 +1,71 @@ +from tests.encoders import base +from tests.utils import RUN_ALL_ENCODERS + + +class TestDenseNetEncoder(base.BaseEncoderTester): + supports_dilated = False + encoder_names = ( + ["densenet121"] + if not RUN_ALL_ENCODERS + else ["densenet121", "densenet169", "densenet161"] + ) + + +class TestDPNEncoder(base.BaseEncoderTester): + encoder_names = ( + ["dpn68"] + if not RUN_ALL_ENCODERS + else ["dpn68", "dpn68b", "dpn92", "dpn98", "dpn107", "dpn131"] + ) + + +class TestInceptionResNetV2Encoder(base.BaseEncoderTester): + supports_dilated = False + encoder_names = ( + ["inceptionresnetv2"] if not RUN_ALL_ENCODERS else ["inceptionresnetv2"] + ) + + +class TestInceptionV4Encoder(base.BaseEncoderTester): + supports_dilated = False + encoder_names = ["inceptionv4"] if not RUN_ALL_ENCODERS else ["inceptionv4"] + + +class TestResNetEncoder(base.BaseEncoderTester): + encoder_names = ( + ["resnet18"] + if not RUN_ALL_ENCODERS + else [ + "resnet18", + "resnet34", + "resnet50", + "resnet101", + "resnet152", + "resnext50_32x4d", + "resnext101_32x4d", + "resnext101_32x8d", + "resnext101_32x16d", + "resnext101_32x32d", + "resnext101_32x48d", + ] + ) + + +class TestSeNetEncoder(base.BaseEncoderTester): + encoder_names = ( + ["se_resnet50"] + if not RUN_ALL_ENCODERS + else [ + "se_resnet50", + 
"se_resnet101", + "se_resnet152", + "se_resnext50_32x4d", + "se_resnext101_32x4d", + # "senet154", # extra large model + ] + ) + + +class TestXceptionEncoder(base.BaseEncoderTester): + supports_dilated = False + encoder_names = ["xception"] if not RUN_ALL_ENCODERS else ["xception"] diff --git a/tests/encoders/test_smp_encoders.py b/tests/encoders/test_smp_encoders.py new file mode 100644 index 00000000..863537bf --- /dev/null +++ b/tests/encoders/test_smp_encoders.py @@ -0,0 +1,41 @@ +from tests.encoders import base +from tests.utils import RUN_ALL_ENCODERS + + +class TestMobileoneEncoder(base.BaseEncoderTester): + encoder_names = ( + ["mobileone_s0"] + if not RUN_ALL_ENCODERS + else [ + "mobileone_s0", + "mobileone_s1", + "mobileone_s2", + "mobileone_s3", + "mobileone_s4", + ] + ) + + +class TestMixTransformerEncoder(base.BaseEncoderTester): + encoder_names = ( + ["mit_b0"] + if not RUN_ALL_ENCODERS + else ["mit_b0", "mit_b1", "mit_b2", "mit_b3", "mit_b4", "mit_b5"] + ) + + +class TestEfficientNetEncoder(base.BaseEncoderTester): + encoder_names = ( + ["efficientnet-b0"] + if not RUN_ALL_ENCODERS + else [ + "efficientnet-b0", + "efficientnet-b1", + "efficientnet-b2", + "efficientnet-b3", + "efficientnet-b4", + "efficientnet-b5", + "efficientnet-b6", + # "efficientnet-b7", # extra large model + ] + ) diff --git a/tests/encoders/test_timm_ported_encoders.py b/tests/encoders/test_timm_ported_encoders.py new file mode 100644 index 00000000..b467c968 --- /dev/null +++ b/tests/encoders/test_timm_ported_encoders.py @@ -0,0 +1,131 @@ +from tests.encoders import base +from tests.utils import RUN_ALL_ENCODERS + + +class TestTimmEfficientNetEncoder(base.BaseEncoderTester): + encoder_names = ( + ["timm-efficientnet-b0"] + if not RUN_ALL_ENCODERS + else [ + "timm-efficientnet-b0", + "timm-efficientnet-b1", + "timm-efficientnet-b2", + "timm-efficientnet-b3", + "timm-efficientnet-b4", + "timm-efficientnet-b5", + "timm-efficientnet-b6", + "timm-efficientnet-b7", + 
"timm-efficientnet-b8", + "timm-efficientnet-l2", + "timm-tf_efficientnet_lite0", + "timm-tf_efficientnet_lite1", + "timm-tf_efficientnet_lite2", + "timm-tf_efficientnet_lite3", + "timm-tf_efficientnet_lite4", + ] + ) + + +class TestTimmGERNetEncoder(base.BaseEncoderTester): + encoder_names = ( + ["timm-gernet_s"] + if not RUN_ALL_ENCODERS + else ["timm-gernet_s", "timm-gernet_m", "timm-gernet_l"] + ) + + +class TestTimmMobileNetV3Encoder(base.BaseEncoderTester): + encoder_names = ( + ["timm-mobilenetv3_small_100"] + if not RUN_ALL_ENCODERS + else [ + "timm-mobilenetv3_large_075", + "timm-mobilenetv3_large_100", + "timm-mobilenetv3_large_minimal_100", + "timm-mobilenetv3_small_075", + "timm-mobilenetv3_small_100", + "timm-mobilenetv3_small_minimal_100", + ] + ) + + +class TestTimmRegNetEncoder(base.BaseEncoderTester): + encoder_names = ( + ["timm-regnetx_002", "timm-regnety_002"] + if not RUN_ALL_ENCODERS + else [ + "timm-regnetx_002", + "timm-regnetx_004", + "timm-regnetx_006", + "timm-regnetx_008", + "timm-regnetx_016", + "timm-regnetx_032", + "timm-regnetx_040", + "timm-regnetx_064", + "timm-regnetx_080", + "timm-regnetx_120", + "timm-regnetx_160", + "timm-regnetx_320", + "timm-regnety_002", + "timm-regnety_004", + "timm-regnety_006", + "timm-regnety_008", + "timm-regnety_016", + "timm-regnety_032", + "timm-regnety_040", + "timm-regnety_064", + "timm-regnety_080", + "timm-regnety_120", + "timm-regnety_160", + "timm-regnety_320", + ] + ) + + +class TestTimmRes2NetEncoder(base.BaseEncoderTester): + supports_dilated = False + encoder_names = ( + ["timm-res2net50_26w_4s"] + if not RUN_ALL_ENCODERS + else [ + "timm-res2net50_26w_4s", + "timm-res2net101_26w_4s", + "timm-res2net50_26w_6s", + "timm-res2net50_26w_8s", + "timm-res2net50_48w_2s", + "timm-res2net50_14w_8s", + "timm-res2next50", + ] + ) + + +class TestTimmResnestEncoder(base.BaseEncoderTester): + default_batch_size = 2 + supports_dilated = False + encoder_names = ( + ["timm-resnest14d"] + if not 
RUN_ALL_ENCODERS + else [ + "timm-resnest14d", + "timm-resnest26d", + "timm-resnest50d", + "timm-resnest101e", + "timm-resnest200e", + "timm-resnest269e", + "timm-resnest50d_4s2x40d", + "timm-resnest50d_1s4x24d", + ] + ) + + +class TestTimmSkNetEncoder(base.BaseEncoderTester): + default_batch_size = 2 + encoder_names = ( + ["timm-skresnet18"] + if not RUN_ALL_ENCODERS + else [ + "timm-skresnet18", + "timm-skresnet34", + "timm-skresnext50_32x4d", + ] + ) diff --git a/tests/encoders/test_timm_universal.py b/tests/encoders/test_timm_universal.py new file mode 100644 index 00000000..753ee4de --- /dev/null +++ b/tests/encoders/test_timm_universal.py @@ -0,0 +1,16 @@ +from tests.encoders import base +from tests.utils import has_timm_test_models + +# check if timm >= 1.0.12 +timm_encoders = [ + "tu-resnet18", # for timm universal traditional-like encoder + "tu-convnext_atto", # for timm universal transformer-like encoder + "tu-darknet17", # for timm universal vgg-like encoder +] + +if has_timm_test_models: + timm_encoders.append("tu-test_resnet.r160_in1k") + + +class TestTimmUniversalEncoder(base.BaseEncoderTester): + encoder_names = timm_encoders diff --git a/tests/encoders/test_torchvision_encoders.py b/tests/encoders/test_torchvision_encoders.py new file mode 100644 index 00000000..99b8b9d5 --- /dev/null +++ b/tests/encoders/test_torchvision_encoders.py @@ -0,0 +1,24 @@ +from tests.encoders import base +from tests.utils import RUN_ALL_ENCODERS + + +class TestMobileoneEncoder(base.BaseEncoderTester): + encoder_names = ["mobilenet_v2"] if not RUN_ALL_ENCODERS else ["mobilenet_v2"] + + +class TestVggEncoder(base.BaseEncoderTester): + supports_dilated = False + encoder_names = ( + ["vgg11"] + if not RUN_ALL_ENCODERS + else [ + "vgg11", + "vgg11_bn", + "vgg13", + "vgg13_bn", + "vgg16", + "vgg16_bn", + "vgg19", + "vgg19_bn", + ] + ) diff --git a/tests/models/__init__.py b/tests/models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/models/base.py 
b/tests/models/base.py new file mode 100644 index 00000000..02e17303 --- /dev/null +++ b/tests/models/base.py @@ -0,0 +1,206 @@ +import os +import pytest +import inspect +import tempfile +import unittest +from functools import lru_cache +from huggingface_hub import hf_hub_download + +import torch +import segmentation_models_pytorch as smp + +from tests.utils import ( + has_timm_test_models, + default_device, + slow_test, + requires_torch_greater_or_equal, +) + + +class BaseModelTester(unittest.TestCase): + test_encoder_name = ( + "tu-test_resnet.r160_in1k" if has_timm_test_models else "resnet18" + ) + + # should be overriden + test_model_type = None + + # test sample configuration + default_batch_size = 1 + default_num_channels = 3 + default_height = 64 + default_width = 64 + + @property + def model_type(self): + if self.test_model_type is None: + raise ValueError("test_model_type is not set") + return self.test_model_type + + @property + def hub_checkpoint(self): + return f"smp-test-models/{self.model_type}-tu-resnet18" + + @property + def model_class(self): + return smp.MODEL_ARCHITECTURES_MAPPING[self.model_type] + + @property + def decoder_channels(self): + signature = inspect.signature(self.model_class) + # check if decoder_channels is in the signature + if "decoder_channels" in signature.parameters: + return signature.parameters["decoder_channels"].default + return None + + @lru_cache + def _get_sample(self, batch_size=1, num_channels=3, height=32, width=32): + return torch.rand(batch_size, num_channels, height, width) + + def test_forward_backward(self): + sample = self._get_sample( + batch_size=self.default_batch_size, + num_channels=self.default_num_channels, + height=self.default_height, + width=self.default_width, + ).to(default_device) + model = smp.create_model( + arch=self.model_type, encoder_name=self.test_encoder_name + ).to(default_device) + + # check default in_channels=3 + output = model(sample) + + # check default output number of classes = 1 + 
expected_number_of_classes = 1 + result_number_of_classes = output.shape[1] + self.assertEqual( + result_number_of_classes, + expected_number_of_classes, + f"Default output number of classes should be {expected_number_of_classes}, but got {result_number_of_classes}", + ) + + # check backward pass + output.mean().backward() + + def test_in_channels_and_depth_and_out_classes( + self, in_channels=1, depth=3, classes=7 + ): + kwargs = {} + + if self.model_type in ["unet", "unetplusplus", "manet"]: + kwargs = {"decoder_channels": self.decoder_channels[:depth]} + + model = smp.create_model( + arch=self.model_type, + encoder_name=self.test_encoder_name, + encoder_depth=depth, + in_channels=in_channels, + classes=classes, + **kwargs, + ).to(default_device) + sample = self._get_sample( + batch_size=self.default_batch_size, + num_channels=in_channels, + height=self.default_height, + width=self.default_width, + ).to(default_device) + + # check in channels correctly set + with torch.no_grad(): + output = model(sample) + + self.assertEqual(output.shape[1], classes) + + def test_classification_head(self): + model = smp.create_model( + arch=self.model_type, + encoder_name=self.test_encoder_name, + aux_params={ + "pooling": "avg", + "classes": 10, + "dropout": 0.5, + "activation": "sigmoid", + }, + ).to(default_device) + + self.assertIsNotNone(model.classification_head) + self.assertIsInstance(model.classification_head[0], torch.nn.AdaptiveAvgPool2d) + self.assertIsInstance(model.classification_head[1], torch.nn.Flatten) + self.assertIsInstance(model.classification_head[2], torch.nn.Dropout) + self.assertEqual(model.classification_head[2].p, 0.5) + self.assertIsInstance(model.classification_head[3], torch.nn.Linear) + self.assertIsInstance(model.classification_head[4].activation, torch.nn.Sigmoid) + + sample = self._get_sample( + batch_size=self.default_batch_size, + num_channels=self.default_num_channels, + height=self.default_height, + width=self.default_width, + 
).to(default_device) + + with torch.no_grad(): + _, cls_probs = model(sample) + + self.assertEqual(cls_probs.shape[1], 10) + + @requires_torch_greater_or_equal("2.0.1") + def test_save_load_with_hub_mixin(self): + # instantiate model + model = smp.create_model( + arch=self.model_type, encoder_name=self.test_encoder_name + ).to(default_device) + + # save model + with tempfile.TemporaryDirectory() as tmpdir: + model.save_pretrained( + tmpdir, dataset="test_dataset", metrics={"my_awesome_metric": 0.99} + ) + restored_model = smp.from_pretrained(tmpdir).to(default_device) + with open(os.path.join(tmpdir, "README.md"), "r") as f: + readme = f.read() + + # check inference is correct + sample = self._get_sample( + batch_size=self.default_batch_size, + num_channels=self.default_num_channels, + height=self.default_height, + width=self.default_width, + ).to(default_device) + + with torch.no_grad(): + output = model(sample) + restored_output = restored_model(sample) + + self.assertEqual(output.shape, restored_output.shape) + self.assertEqual(output.shape[1], 1) + + # check dataset and metrics are saved in readme + self.assertIn("test_dataset", readme) + self.assertIn("my_awesome_metric", readme) + + @slow_test + @requires_torch_greater_or_equal("2.0.1") + @pytest.mark.logits_match + def test_preserve_forward_output(self): + model = smp.from_pretrained(self.hub_checkpoint).eval().to(default_device) + + input_tensor_path = hf_hub_download( + repo_id=self.hub_checkpoint, filename="input-tensor.pth" + ) + output_tensor_path = hf_hub_download( + repo_id=self.hub_checkpoint, filename="output-tensor.pth" + ) + + input_tensor = torch.load(input_tensor_path, weights_only=True) + input_tensor = input_tensor.to(default_device) + output_tensor = torch.load(output_tensor_path, weights_only=True) + output_tensor = output_tensor.to(default_device) + + with torch.no_grad(): + output = model(input_tensor) + + self.assertEqual(output.shape, output_tensor.shape) + is_close = 
torch.allclose(output, output_tensor, atol=5e-2) + max_diff = torch.max(torch.abs(output - output_tensor)) + self.assertTrue(is_close, f"Max diff: {max_diff}") diff --git a/tests/models/test_deeplab.py b/tests/models/test_deeplab.py new file mode 100644 index 00000000..d3d350e9 --- /dev/null +++ b/tests/models/test_deeplab.py @@ -0,0 +1,16 @@ +import pytest +from tests.models import base + + +@pytest.mark.deeplabv3 +class TestDeeplabV3Model(base.BaseModelTester): + test_model_type = "deeplabv3" + + default_batch_size = 2 + + +@pytest.mark.deeplabv3plus +class TestDeeplabV3PlusModel(base.BaseModelTester): + test_model_type = "deeplabv3plus" + + default_batch_size = 2 diff --git a/tests/models/test_fpn.py b/tests/models/test_fpn.py new file mode 100644 index 00000000..15ae1f6a --- /dev/null +++ b/tests/models/test_fpn.py @@ -0,0 +1,7 @@ +import pytest +from tests.models import base + + +@pytest.mark.fpn +class TestFpnModel(base.BaseModelTester): + test_model_type = "fpn" diff --git a/tests/models/test_linknet.py b/tests/models/test_linknet.py new file mode 100644 index 00000000..1ab5eb4e --- /dev/null +++ b/tests/models/test_linknet.py @@ -0,0 +1,7 @@ +import pytest +from tests.models import base + + +@pytest.mark.linknet +class TestLinknetModel(base.BaseModelTester): + test_model_type = "linknet" diff --git a/tests/models/test_manet.py b/tests/models/test_manet.py new file mode 100644 index 00000000..33a8ae3b --- /dev/null +++ b/tests/models/test_manet.py @@ -0,0 +1,7 @@ +import pytest +from tests.models import base + + +@pytest.mark.manet +class TestManetModel(base.BaseModelTester): + test_model_type = "manet" diff --git a/tests/models/test_pan.py b/tests/models/test_pan.py new file mode 100644 index 00000000..d66fefe0 --- /dev/null +++ b/tests/models/test_pan.py @@ -0,0 +1,11 @@ +import pytest +from tests.models import base + + +@pytest.mark.pan +class TestPanModel(base.BaseModelTester): + test_model_type = "pan" + + default_batch_size = 2 + default_height = 128 + 
default_width = 128 diff --git a/tests/models/test_psp.py b/tests/models/test_psp.py new file mode 100644 index 00000000..2603cdda --- /dev/null +++ b/tests/models/test_psp.py @@ -0,0 +1,9 @@ +import pytest +from tests.models import base + + +@pytest.mark.psp +class TestPspModel(base.BaseModelTester): + test_model_type = "pspnet" + + default_batch_size = 2 diff --git a/tests/models/test_segformer.py b/tests/models/test_segformer.py new file mode 100644 index 00000000..3ca5016c --- /dev/null +++ b/tests/models/test_segformer.py @@ -0,0 +1,43 @@ +import torch +import pytest +import segmentation_models_pytorch as smp + +from tests.models import base +from tests.utils import slow_test, default_device, requires_torch_greater_or_equal + + +@pytest.mark.segformer +class TestSegformerModel(base.BaseModelTester): + test_model_type = "segformer" + + @slow_test + @requires_torch_greater_or_equal("2.0.1") + @pytest.mark.logits_match + def test_load_pretrained(self): + hub_checkpoint = "smp-hub/segformer-b0-512x512-ade-160k" + + model = smp.from_pretrained(hub_checkpoint) + model = model.eval().to(default_device) + + sample = torch.ones([1, 3, 512, 512]).to(default_device) + + with torch.no_grad(): + output = model(sample) + + self.assertEqual(output.shape, (1, 150, 512, 512)) + + expected_logits_slice = torch.tensor( + [-4.4172, -4.4723, -4.5273, -4.5824, -4.6375, -4.7157] + ) + resulted_logits_slice = output[0, 0, 256, :6].cpu() + is_equal = torch.allclose( + expected_logits_slice, resulted_logits_slice, atol=1e-2 + ) + max_diff = torch.max(torch.abs(expected_logits_slice - resulted_logits_slice)) + self.assertTrue( + is_equal, + f"Expected logits slice and resulted logits slice are not equal.\n" + f"Max diff: {max_diff}\n" + f"Expected: {expected_logits_slice}\n" + f"Resulted: {resulted_logits_slice}\n", + ) diff --git a/tests/models/test_unet.py b/tests/models/test_unet.py new file mode 100644 index 00000000..54c69bf0 --- /dev/null +++ b/tests/models/test_unet.py @@ -0,0 
+1,7 @@ +import pytest +from tests.models import base + + +@pytest.mark.unet +class TestUnetModel(base.BaseModelTester): + test_model_type = "unet" diff --git a/tests/models/test_unetplusplus.py b/tests/models/test_unetplusplus.py new file mode 100644 index 00000000..9e67f2ed --- /dev/null +++ b/tests/models/test_unetplusplus.py @@ -0,0 +1,7 @@ +import pytest +from tests.models import base + + +@pytest.mark.unetplusplus +class TestUnetPlusPlusModel(base.BaseModelTester): + test_model_type = "unetplusplus" diff --git a/tests/models/test_upernet.py b/tests/models/test_upernet.py new file mode 100644 index 00000000..71d703f9 --- /dev/null +++ b/tests/models/test_upernet.py @@ -0,0 +1,8 @@ +import pytest +from tests.models import base + + +@pytest.mark.upernet +class TestUnetModel(base.BaseModelTester): + test_model_type = "upernet" + default_batch_size = 2 diff --git a/tests/test_models.py b/tests/test_models.py deleted file mode 100644 index 460dcdf2..00000000 --- a/tests/test_models.py +++ /dev/null @@ -1,194 +0,0 @@ -import pytest -import torch - -import segmentation_models_pytorch as smp # noqa - - -def get_encoders(): - exclude_encoders = [ - "senet154", - "resnext101_32x16d", - "resnext101_32x32d", - "resnext101_32x48d", - ] - encoders = smp.encoders.get_encoder_names() - encoders = [e for e in encoders if e not in exclude_encoders] - encoders.append("tu-resnet34") # for timm universal traditional-like encoder - encoders.append("tu-convnext_atto") # for timm universal transformer-like encoder - encoders.append("tu-darknet17") # for timm universal vgg-like encoder - return encoders - - -ENCODERS = get_encoders() -DEFAULT_ENCODER = "resnet18" - - -def get_sample(model_class): - if model_class in [ - smp.FPN, - smp.Linknet, - smp.Unet, - smp.UnetPlusPlus, - smp.MAnet, - smp.Segformer, - ]: - sample = torch.ones([1, 3, 64, 64]) - elif model_class in [smp.PAN, smp.DeepLabV3, smp.DeepLabV3Plus]: - sample = torch.ones([2, 3, 128, 128]) - elif model_class in 
[smp.PSPNet, smp.UPerNet]: - # Batch size 2 needed due to nn.BatchNorm2d not supporting (1, C, 1, 1) input - # from PSPModule pooling in PSPNet/UPerNet. - sample = torch.ones([2, 3, 64, 64]) - else: - raise ValueError("Not supported model class {}".format(model_class)) - return sample - - -def _test_forward(model, sample, test_shape=False): - with torch.no_grad(): - out = model(sample) - if test_shape: - assert out.shape[2:] == sample.shape[2:] - - -def _test_forward_backward(model, sample, test_shape=False): - out = model(sample) - out.mean().backward() - if test_shape: - assert out.shape[2:] == sample.shape[2:] - - -@pytest.mark.parametrize("encoder_name", ENCODERS) -@pytest.mark.parametrize("encoder_depth", [3, 5]) -@pytest.mark.parametrize( - "model_class", - [ - smp.FPN, - smp.PSPNet, - smp.Linknet, - smp.Unet, - smp.UnetPlusPlus, - smp.MAnet, - smp.UPerNet, - smp.Segformer, - ], -) -def test_forward(model_class, encoder_name, encoder_depth, **kwargs): - if ( - model_class is smp.Unet - or model_class is smp.UnetPlusPlus - or model_class is smp.MAnet - ): - kwargs["decoder_channels"] = (16, 16, 16, 16, 16)[-encoder_depth:] - if model_class in [smp.UnetPlusPlus, smp.Linknet]: - if encoder_name.startswith("mit_b") or encoder_name.startswith("tu-convnext"): - return # skip transformer-like model* - if model_class is smp.FPN and encoder_depth != 5: - if encoder_name.startswith("mit_b") or encoder_name.startswith("tu-convnext"): - return # skip transformer-like model* - model = model_class( - encoder_name, encoder_depth=encoder_depth, encoder_weights=None, **kwargs - ) - sample = get_sample(model_class) - model.eval() - if encoder_depth == 5 and model_class != smp.PSPNet: - test_shape = True - else: - test_shape = False - - _test_forward(model, sample, test_shape) - - -@pytest.mark.parametrize( - "model_class", - [ - smp.PAN, - smp.FPN, - smp.PSPNet, - smp.Linknet, - smp.Unet, - smp.UnetPlusPlus, - smp.MAnet, - smp.DeepLabV3, - smp.DeepLabV3Plus, - smp.UPerNet, - 
smp.Segformer, - ], -) -def test_forward_backward(model_class): - sample = get_sample(model_class) - model = model_class(DEFAULT_ENCODER, encoder_weights=None) - _test_forward_backward(model, sample) - - -@pytest.mark.parametrize( - "model_class", - [ - smp.PAN, - smp.FPN, - smp.PSPNet, - smp.Linknet, - smp.Unet, - smp.UnetPlusPlus, - smp.MAnet, - smp.DeepLabV3, - smp.DeepLabV3Plus, - smp.UPerNet, - smp.Segformer, - ], -) -def test_aux_output(model_class): - model = model_class( - DEFAULT_ENCODER, encoder_weights=None, aux_params=dict(classes=2) - ) - sample = get_sample(model_class) - label_size = (sample.shape[0], 2) - mask, label = model(sample) - assert label.size() == label_size - - -@pytest.mark.parametrize("upsampling", [2, 4, 8]) -@pytest.mark.parametrize("model_class", [smp.FPN, smp.PSPNet]) -def test_upsample(model_class, upsampling): - default_upsampling = 4 if model_class is smp.FPN else 8 - model = model_class(DEFAULT_ENCODER, encoder_weights=None, upsampling=upsampling) - sample = get_sample(model_class) - mask = model(sample) - assert mask.size()[-1] / 64 == upsampling / default_upsampling - - -@pytest.mark.parametrize("model_class", [smp.FPN]) -@pytest.mark.parametrize("in_channels", [1, 2, 4]) -def test_in_channels(model_class, in_channels): - sample = torch.ones([1, in_channels, 64, 64]) - model = model_class(DEFAULT_ENCODER, encoder_weights=None, in_channels=in_channels) - model.eval() - with torch.no_grad(): - model(sample) - - assert model.encoder._in_channels == in_channels - - -@pytest.mark.parametrize("encoder_name", ENCODERS) -def test_dilation(encoder_name): - if ( - encoder_name in ["inceptionresnetv2", "xception", "inceptionv4"] - or encoder_name.startswith("vgg") - or encoder_name.startswith("densenet") - or encoder_name.startswith("timm-res") - ): - return - - encoder = smp.encoders.get_encoder(encoder_name, output_stride=16) - - encoder.eval() - with torch.no_grad(): - sample = torch.ones([1, 3, 64, 64]) - output = encoder(sample) - - 
shapes = [out.shape[-1] for out in output] - assert shapes == [64, 32, 16, 8, 4, 4] # last downsampling replaced with dilation - - -if __name__ == "__main__": - pytest.main([__file__]) diff --git a/tests/utils.py b/tests/utils.py new file mode 100644 index 00000000..e8bce88e --- /dev/null +++ b/tests/utils.py @@ -0,0 +1,47 @@ +import os +import timm +import torch +import unittest + +from packaging.version import Version + + +has_timm_test_models = Version(timm.__version__) >= Version("1.0.12") +default_device = "cuda" if torch.cuda.is_available() else "cpu" + + +def get_commit_message(): + commit_msg = os.getenv("COMMIT_MESSAGE", "") + return commit_msg.lower() + + +# Check both environment variables and commit message +commit_message = get_commit_message() +RUN_ALL_ENCODERS = ( + os.getenv("RUN_ALL_ENCODERS", "false").lower() in ["true", "1", "y", "yes"] + or "run-all-encoders" in commit_message +) + +RUN_SLOW = ( + os.getenv("RUN_SLOW", "false").lower() in ["true", "1", "y", "yes"] + or "run-slow" in commit_message +) + + +def slow_test(test_case): + """ + Decorator marking a test as slow. + + Slow tests are skipped by default. Set the RUN_SLOW environment variable to a truthy value to run them. + + """ + return unittest.skipUnless(RUN_SLOW, "test is slow")(test_case) + + +def requires_torch_greater_or_equal(version: str): + torch_version = Version(torch.__version__) + provided_version = Version(version) + return unittest.skipUnless( + torch_version >= provided_version, + f"torch version {torch_version} is less than {provided_version}", + )