Skip to content

Commit

Permalink
Refactor tests (#1011)
Browse files Browse the repository at this point in the history
* Add parallel test deps

* Update signature

* Add encoders tests

* Update gitignore

* Update encoders for timm-universal

* Add parallel tests run

* Disable models tests

* Add uv to CI

* Add uv to minimum

* Add show-install-packages

* Increase to 3 workers

* Fix show-packages

* Change back for 2 workers

* Add coverage

* Basic model test

* Fix

* Move model archs

* Add base params test

* Fix timm test for minimum version

* Remove deprecated utils from coverage

* Fix

* Fix

* Exclude conversion script

* Add save-load test, add aux head test

* Remove custom encoder

* Set encoder for models tests

* Docs + flag for anyres

* Fix loading from config

* Bump min hf-hub to 0.25.0

* Fix minimal

* Add test with hub checkpoint

* Fixing minimum

* Fix

* Fix torch for minimum tests

* Update torch version and run-slow

* run-slow

* Show skipped

* [run-slow] Fixing minimum

* [run-slow] Fixing minimum

* Fix decorator

* Raise error

* [run-slow] Fixing run slow

* [run-slow] Fixing run slow

* Run slow tests in separate job

* FIx

* Fixes

* Add device

* Bum tolerance

* Add device

* Fixup
  • Loading branch information
qubvel authored Dec 23, 2024
1 parent fbeeb0c commit 900ac49
Show file tree
Hide file tree
Showing 42 changed files with 1,046 additions and 223 deletions.
49 changes: 43 additions & 6 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,24 +36,61 @@ jobs:
runs-on: ${{ matrix.os }}
steps:
- uses: actions/checkout@v4

- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}

- name: Install dependencies
run: python -m pip install -r requirements/required.txt -r requirements/test.txt
- name: Test with pytest
run: pytest
run: |
python -m pip install uv
python -m uv pip install --system -r requirements/required.txt -r requirements/test.txt
- name: Show installed packages
run: |
python -m pip list
- name: Test with PyTest
run: |
pytest -v -rsx -n 2 --cov=segmentation_models_pytorch --cov-report=xml --cov-config=pyproject.toml -k "not logits_match"
- name: Upload coverage reports to Codecov
uses: codecov/codecov-action@v5
with:
token: ${{ secrets.CODECOV_TOKEN }}
slug: qubvel-org/segmentation_models.pytorch
if: matrix.os == 'macos-latest' && matrix.python-version == '3.12'

test_logits_match:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: "3.10"
- name: Install dependencies
run: |
python -m pip install uv
python -m uv pip install --system -r requirements/required.txt -r requirements/test.txt
- name: Test with PyTest
run: RUN_SLOW=1 pytest -v -rsx -n 2 -k "logits_match"

minimum:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }}
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: "3.9"
- name: Install dependencies
run: python -m pip install -r requirements/minimum.old -r requirements/test.txt
run: |
python -m pip install uv
python -m uv pip install --system -r requirements/minimum.old -r requirements/test.txt
- name: Show installed packages
run: |
python -m pip list
- name: Test with pytest
run: pytest
run: pytest -v -rsx -n 2 -k "not logits_match"
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ venv/
ENV/
env.bak/
venv.bak/
.vscode/

# Spyder project settings
.spyderproject
Expand Down
5 changes: 4 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@ install_dev: .venv
.venv/bin/pip install -e ".[test]"

test: .venv
.venv/bin/pytest -p no:cacheprovider tests/
.venv/bin/pytest -v -rsx -n 2 tests/ -k "not logits_match"

test_all: .venv
RUN_SLOW=1 .venv/bin/pytest -v -rsx -n 2 tests/

table:
.venv/bin/python misc/generate_table.py
Expand Down
41 changes: 41 additions & 0 deletions misc/generate_test_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import os
import torch
import tempfile
import huggingface_hub
import segmentation_models_pytorch as smp

HUB_REPO = "smp-test-models"
ENCODER_NAME = "tu-resnet18"

api = huggingface_hub.HfApi(token=os.getenv("HF_TOKEN"))

for model_name, model_class in smp.MODEL_ARCHITECTURES_MAPPING.items():
model = model_class(encoder_name=ENCODER_NAME)
model = model.eval()

# generate test sample
torch.manual_seed(423553)
sample = torch.rand(1, 3, 256, 256)

with torch.no_grad():
output = model(sample)

with tempfile.TemporaryDirectory() as tmpdir:
# save model
model.save_pretrained(f"{tmpdir}")

# save input and output
torch.save(sample, f"{tmpdir}/input-tensor.pth")
torch.save(output, f"{tmpdir}/output-tensor.pth")

# create repo
repo_id = f"{HUB_REPO}/{model_name}-{ENCODER_NAME}"
if not api.repo_exists(repo_id=repo_id):
api.create_repo(repo_id=repo_id, repo_type="model")

# upload to hub
api.upload_folder(
folder_path=tmpdir,
repo_id=f"{HUB_REPO}/{model_name}-{ENCODER_NAME}",
repo_type="model",
)
24 changes: 24 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ docs = [
]
test = [
'pytest',
'pytest-cov',
'pytest-xdist',
'ruff',
]

Expand All @@ -55,3 +57,25 @@ version = {attr = 'segmentation_models_pytorch.__version__.__version__'}

[tool.setuptools.packages.find]
include = ['segmentation_models_pytorch*']

[tool.pytest.ini_options]
markers = [
"deeplabv3",
"deeplabv3plus",
"fpn",
"linknet",
"manet",
"pan",
"psp",
"segformer",
"unet",
"unetplusplus",
"upernet",
"logits_match",
]

[tool.coverage.run]
omit = [
"segmentation_models_pytorch/utils/*",
"**/convert_*",
]
1 change: 1 addition & 0 deletions requirements/minimum.old
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@ timm==0.9.0
torch==1.9.0
torchvision==0.10.0
tqdm==4.42.1
Jinja2==3.0.0
4 changes: 3 additions & 1 deletion requirements/test.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
pytest==8.3.4
ruff==0.8.4
pytest-xdist==3.6.1
pytest-cov==6.0.0
ruff==0.8.4
33 changes: 17 additions & 16 deletions segmentation_models_pytorch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,21 @@
"ignore", message=r'"is" with \'str\' literal.*', category=SyntaxWarning
) # for python >= 3.12

_MODEL_ARCHITECTURES = [
Unet,
UnetPlusPlus,
MAnet,
Linknet,
FPN,
PSPNet,
DeepLabV3,
DeepLabV3Plus,
PAN,
UPerNet,
Segformer,
]
MODEL_ARCHITECTURES_MAPPING = {a.__name__.lower(): a for a in _MODEL_ARCHITECTURES}


def create_model(
arch: str,
Expand All @@ -43,26 +58,12 @@ def create_model(
parameters, without using its class
"""

archs = [
Unet,
UnetPlusPlus,
MAnet,
Linknet,
FPN,
PSPNet,
DeepLabV3,
DeepLabV3Plus,
PAN,
UPerNet,
Segformer,
]
archs_dict = {a.__name__.lower(): a for a in archs}
try:
model_class = archs_dict[arch.lower()]
model_class = MODEL_ARCHITECTURES_MAPPING[arch.lower()]
except KeyError:
raise KeyError(
"Wrong architecture type `{}`. Available options are: {}".format(
arch, list(archs_dict.keys())
arch, list(MODEL_ARCHITECTURES_MAPPING.keys())
)
)
return model_class(
Expand Down
11 changes: 11 additions & 0 deletions segmentation_models_pytorch/base/hub_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,3 +136,14 @@ def from_pretrained(pretrained_model_name_or_path: str, *args, **kwargs):

model_class = getattr(smp, model_class_name)
return model_class.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)


def supports_config_loading(func):
"""Decorator to filter special config kwargs"""

@wraps(func)
def wrapper(self, *args, **kwargs):
kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")}
return func(self, *args, **kwargs)

return wrapper
11 changes: 10 additions & 1 deletion segmentation_models_pytorch/base/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,22 @@


class SegmentationModel(torch.nn.Module, SMPHubMixin):
"""Base class for all segmentation models."""

# if model supports shape not divisible by 2 ^ n
# set to False
requires_divisible_input_shape = True

def initialize(self):
init.initialize_decoder(self.decoder)
init.initialize_head(self.segmentation_head)
if self.classification_head is not None:
init.initialize_head(self.classification_head)

def check_input_shape(self, x):
"""Check if the input shape is divisible by the output stride.
If not, raise a RuntimeError.
"""
h, w = x.shape[-2:]
output_stride = self.encoder.output_stride
if h % output_stride != 0 or w % output_stride != 0:
Expand All @@ -33,7 +42,7 @@ def check_input_shape(self, x):
def forward(self, x):
"""Sequentially pass `x` trough model`s encoder, decoder and heads"""

if not torch.jit.is_tracing():
if not torch.jit.is_tracing() or self.requires_divisible_input_shape:
self.check_input_shape(x)

features = self.encoder(x)
Expand Down
3 changes: 3 additions & 0 deletions segmentation_models_pytorch/decoders/deeplabv3/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
SegmentationModel,
)
from segmentation_models_pytorch.encoders import get_encoder
from segmentation_models_pytorch.base.hub_mixin import supports_config_loading

from .decoder import DeepLabV3Decoder, DeepLabV3PlusDecoder

Expand Down Expand Up @@ -54,6 +55,7 @@ class DeepLabV3(SegmentationModel):
"""

@supports_config_loading
def __init__(
self,
encoder_name: str = "resnet34",
Expand Down Expand Up @@ -163,6 +165,7 @@ class DeepLabV3Plus(SegmentationModel):
"""

@supports_config_loading
def __init__(
self,
encoder_name: str = "resnet34",
Expand Down
2 changes: 2 additions & 0 deletions segmentation_models_pytorch/decoders/fpn/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
SegmentationModel,
)
from segmentation_models_pytorch.encoders import get_encoder
from segmentation_models_pytorch.base.hub_mixin import supports_config_loading

from .decoder import FPNDecoder

Expand Down Expand Up @@ -51,6 +52,7 @@ class FPN(SegmentationModel):
"""

@supports_config_loading
def __init__(
self,
encoder_name: str = "resnet34",
Expand Down
2 changes: 2 additions & 0 deletions segmentation_models_pytorch/decoders/linknet/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
SegmentationModel,
)
from segmentation_models_pytorch.encoders import get_encoder
from segmentation_models_pytorch.base.hub_mixin import supports_config_loading

from .decoder import LinknetDecoder

Expand Down Expand Up @@ -53,6 +54,7 @@ class Linknet(SegmentationModel):
https://arxiv.org/abs/1707.03718
"""

@supports_config_loading
def __init__(
self,
encoder_name: str = "resnet34",
Expand Down
2 changes: 2 additions & 0 deletions segmentation_models_pytorch/decoders/manet/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
SegmentationModel,
)
from segmentation_models_pytorch.encoders import get_encoder
from segmentation_models_pytorch.base.hub_mixin import supports_config_loading

from .decoder import MAnetDecoder

Expand Down Expand Up @@ -56,6 +57,7 @@ class MAnet(SegmentationModel):
"""

@supports_config_loading
def __init__(
self,
encoder_name: str = "resnet34",
Expand Down
2 changes: 2 additions & 0 deletions segmentation_models_pytorch/decoders/pan/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
SegmentationModel,
)
from segmentation_models_pytorch.encoders import get_encoder
from segmentation_models_pytorch.base.hub_mixin import supports_config_loading

from .decoder import PANDecoder

Expand Down Expand Up @@ -53,6 +54,7 @@ class PAN(SegmentationModel):
"""

@supports_config_loading
def __init__(
self,
encoder_name: str = "resnet34",
Expand Down
2 changes: 2 additions & 0 deletions segmentation_models_pytorch/decoders/pspnet/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
SegmentationModel,
)
from segmentation_models_pytorch.encoders import get_encoder
from segmentation_models_pytorch.base.hub_mixin import supports_config_loading

from .decoder import PSPDecoder

Expand Down Expand Up @@ -54,6 +55,7 @@ class PSPNet(SegmentationModel):
https://arxiv.org/abs/1612.01105
"""

@supports_config_loading
def __init__(
self,
encoder_name: str = "resnet34",
Expand Down
2 changes: 2 additions & 0 deletions segmentation_models_pytorch/decoders/segformer/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
SegmentationModel,
)
from segmentation_models_pytorch.encoders import get_encoder
from segmentation_models_pytorch.base.hub_mixin import supports_config_loading

from .decoder import SegformerDecoder

Expand Down Expand Up @@ -46,6 +47,7 @@ class Segformer(SegmentationModel):
"""

@supports_config_loading
def __init__(
self,
encoder_name: str = "resnet34",
Expand Down
Loading

0 comments on commit 900ac49

Please sign in to comment.