Skip to content

Commit

Permalink
Make strict raise error in transforms (#2314)
Browse files Browse the repository at this point in the history
* Temp commit

* Added option to pass strict to transform
  • Loading branch information
ternaus authored Jan 28, 2025
1 parent 8008ca2 commit 7d95c93
Show file tree
Hide file tree
Showing 11 changed files with 120 additions and 38 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ repos:
- id: python-use-type-annotations
- id: text-unicode-replacement-char
- repo: https://github.com/codespell-project/codespell
rev: v2.4.0
rev: v2.4.1
hooks:
- id: codespell
additional_dependencies: ["tomli"]
Expand Down
5 changes: 4 additions & 1 deletion albumentations/core/transforms_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def __init__(self, downscale: int = cv2.INTER_NEAREST, upscale: int = cv2.INTER_
class BaseTransformInitSchema(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
p: ProbabilityType
strict: bool = False


class CombinedMeta(SerializableMeta, ValidatedTransformMeta):
Expand Down Expand Up @@ -65,7 +66,6 @@ class InitSchema(BaseTransformInitSchema):
def __init__(self, p: float = 0.5):
self.p = p
self._additional_targets: dict[str, str] = {}
# replay mode params
self.params: dict[Any, Any] = {}
self._key2func = {}
self._set_keys()
Expand Down Expand Up @@ -349,9 +349,12 @@ def get_transform_init_args(self) -> dict[str, Any]:
return args

def to_dict_private(self) -> dict[str, Any]:
"""Returns a dictionary representation of the transform, excluding internal parameters."""
state = {"__class_fullname__": self.get_class_fullname()}
state.update(self.get_base_init_args())
state.update(self.get_transform_init_args())
# Remove strict from serialization
state.pop("strict", None)
return state


Expand Down
36 changes: 26 additions & 10 deletions albumentations/core/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ def custom_init(self: Any, *args: Any, **kwargs: Any) -> None:
full_kwargs: dict[str, Any] = dict(zip(param_names, args))
full_kwargs.update(kwargs)

# Get strict value before validation
strict = full_kwargs.pop("strict", False) # Remove strict from kwargs

for parameter_name, parameter in init_params.items():
if (
parameter_name != "self"
Expand All @@ -31,24 +34,37 @@ def custom_init(self: Any, *args: Any, **kwargs: Any) -> None:
):
full_kwargs[parameter_name] = parameter.default

# No try-except block needed as we want the exception to propagate naturally
config = dct["InitSchema"](**full_kwargs)
# Configure model validation
try:
config = dct["InitSchema"](**{k: v for k, v in full_kwargs.items() if k in param_names})
validated_kwargs = config.model_dump()
# Remove strict from validated kwargs to prevent it from being passed to __init__
validated_kwargs.pop("strict", None)
except Exception as e:
if strict:
raise
warn(str(e), stacklevel=2)
# Use default values for invalid parameters
config = dct["InitSchema"]()
validated_kwargs = config.model_dump()
validated_kwargs.pop("strict", None) # Also remove from default values

validated_kwargs = config.model_dump()
for name_arg in kwargs:
if name_arg not in validated_kwargs:
warn(
f"Argument '{name_arg}' is not valid and will be ignored.",
stacklevel=2,
)
invalid_args = [
name_arg for name_arg in kwargs if name_arg not in validated_kwargs and name_arg != "strict"
]
if invalid_args:
message = f"Argument(s) '{', '.join(invalid_args)}' are not valid for transform {name}"
if strict:
raise ValueError(message)
warn(message, stacklevel=2)

# Call original init with validated kwargs (strict removed)
original_init(self, **validated_kwargs)

# Preserve the original signature and docstring
custom_init.__signature__ = original_sig # type: ignore[attr-defined]
custom_init.__doc__ = original_init.__doc__

# Rename __init__ to custom_init to avoid the N807 warning
dct["__init__"] = custom_init

return super().__new__(cls, name, bases, dct)
2 changes: 1 addition & 1 deletion tests/aug_definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,5 +413,5 @@
[A.CubicSymmetry, {}],
[A.AtLeastOneBBoxRandomCrop, {"height": 80, "width": 80, "erosion_factor": 0.2}],
[A.ConstrainedCoarseDropout, {"num_holes_range": (1, 3), "hole_height_range": (0.1, 0.2), "hole_width_range": (0.1, 0.2), "fill": 0, "fill_mask": 0, "mask_indices": [1]}],
[A.RandomSizedBBoxSafeCrop, {"height": 80, "width": 80, "erosion_factor": 0.2}],
[A.RandomSizedBBoxSafeCrop, {"height": 80, "width": 80, "erosion_rate": 0.2}],
]
2 changes: 1 addition & 1 deletion tests/test_augmentations.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,7 +454,7 @@ def test_image_only_crop_around_bbox_augmentation(augmentation_cls, params, imag
[A.Rotate, {"border_mode": cv2.BORDER_CONSTANT, "fill": 100, "fill_mask": 1}],
[A.SafeRotate, {"border_mode": cv2.BORDER_CONSTANT, "fill": 100, "fill_mask": 1}],
[A.ShiftScaleRotate, {"border_mode": cv2.BORDER_CONSTANT, "fill": 100, "fill_mask": 1}],
[A.Affine, {"mode": cv2.BORDER_CONSTANT, "fill_mask": 1, "fill": 100}],
[A.Affine, {"border_mode": cv2.BORDER_CONSTANT, "fill_mask": 1, "fill": 100}],
],
)
def test_mask_fill_value(augmentation_cls, params):
Expand Down
2 changes: 1 addition & 1 deletion tests/test_bbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -1027,7 +1027,7 @@ def test_bounding_box_vflip(bbox, expected_bbox) -> None:
@pytest.mark.parametrize(
"get_transform",
[
lambda sign: A.Affine(translate_px=sign * 2, mode=cv2.BORDER_CONSTANT, fill=255),
lambda sign: A.Affine(translate_px=sign * 2, border_mode=cv2.BORDER_CONSTANT, fill=255),
lambda sign: A.ShiftScaleRotate(
shift_limit=(sign * 0.02, sign * 0.02),
scale_limit=0,
Expand Down
66 changes: 62 additions & 4 deletions tests/test_core.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from collections.abc import Callable
import typing
from unittest import mock
from unittest.mock import MagicMock, Mock, call, patch
import warnings
import torch
import cv2
import numpy as np
Expand All @@ -20,7 +22,7 @@
Sequential,
SomeOf,
)
from albumentations.core.transforms_interface import DualTransform, ImageOnlyTransform, NoOp
from albumentations.core.transforms_interface import BasicTransform, DualTransform, ImageOnlyTransform, NoOp
from albumentations.core.utils import to_tuple, get_shape
from tests.conftest import (
IMAGES,
Expand Down Expand Up @@ -458,7 +460,7 @@ def test_check_each_transform_compose(targets, bbox_params, keypoint_params, exp
image = np.empty([100, 100], dtype=np.uint8)

augs = Compose(
[Compose([A.Crop(0, 0, 50, 50), A.PadIfNeeded(100, 100, border_mode=cv2.BORDER_CONSTANT, value=0)])],
[Compose([A.Crop(0, 0, 50, 50), A.PadIfNeeded(100, 100, border_mode=cv2.BORDER_CONSTANT, fill=0)])],
bbox_params=bbox_params,
keypoint_params=keypoint_params,
seed=137
Expand Down Expand Up @@ -548,7 +550,7 @@ def test_check_each_transform_sequential(targets, bbox_params, keypoint_params,
image = np.empty([100, 100], dtype=np.uint8)

augs = Compose(
[Sequential([A.Crop(0, 0, 50, 50), A.PadIfNeeded(100, 100, border_mode=cv2.BORDER_CONSTANT, value=0)], p=1.0)],
[Sequential([A.Crop(0, 0, 50, 50), A.PadIfNeeded(100, 100, border_mode=cv2.BORDER_CONSTANT, fill=0)], p=1.0)],
bbox_params=bbox_params,
keypoint_params=keypoint_params,
)
Expand Down Expand Up @@ -639,7 +641,7 @@ def test_check_each_transform_someof(targets, bbox_params, keypoint_params, expe
augs = Compose(
[
SomeOf([A.Crop(0, 0, 50, 50)], n=1, replace=False, p=1.0),
SomeOf([A.PadIfNeeded(100, 100, border_mode=cv2.BORDER_CONSTANT, value=0)], n=1, replace=False, p=1.0),
SomeOf([A.PadIfNeeded(100, 100, border_mode=cv2.BORDER_CONSTANT, fill=0)], n=1, replace=False, p=1.0),
],
bbox_params=bbox_params,
keypoint_params=keypoint_params,
Expand Down Expand Up @@ -1265,6 +1267,31 @@ def test_masks_as_target(augmentation_cls, params, masks):
A.ConstrainedCoarseDropout,
A.PadIfNeeded,
A.RandomRotate90,
A.D4,
A.GridDistortion,
A.ElasticTransform,
A.GridElasticDeform,
A.HorizontalFlip,
A.VerticalFlip,
A.Transpose,
A.LongestMaxSize,
A.SmallestMaxSize,
A.RandomGridShuffle,
A.Morphological,
A.NoOp,
A.OpticalDistortion,
A.Pad,
A.PiecewiseAffine,
A.RandomScale,
A.RandomSizedBBoxSafeCrop,
A.RandomSizedCrop,
A.RandomResizedCrop,
A.RandomRotate90,
A.RandomCropFromBorders,
A.Resize,
A.ThinPlateSpline,
A.TimeReverse,
A.TimeMasking
},
),
)
Expand Down Expand Up @@ -1550,3 +1577,34 @@ def test_get_shape_empty_arrays(key):
shape = get_shape(data)
assert isinstance(shape, dict)
assert all(isinstance(v, int) for v in shape.values())


def test_transform_strict_mode_raises_error():
# Test that strict=True raises error for invalid parameters
with pytest.raises(ValueError, match="Argument\\(s\\) 'invalid_param' are not valid for transform Blur"):
A.Blur(strict=True, invalid_param=123)

def test_transform_non_strict_mode_shows_warning():
# Test that strict=False (default) shows warning for invalid parameters
with pytest.warns(UserWarning, match="Argument\\(s\\) 'invalid_param' are not valid for transform Blur"):
transform = A.Blur(invalid_param=123)
assert transform.p == 0.5 # Check that transform was still created with default values

def test_transform_valid_params_no_warning():
# Test that no warning/error is raised for valid parameters
with warnings.catch_warnings():
warnings.simplefilter("error") # Convert warnings to errors to ensure none are raised
transform = A.Blur(p=0.7, blur_limit=(3, 5))
assert transform.p == 0.7
assert transform.blur_limit == (3, 5)

def test_transform_multiple_invalid_params():
# Test handling of multiple invalid parameters
with pytest.raises(ValueError, match="Argument\\(s\\) 'invalid1, invalid2' are not valid for transform Blur"):
A.Blur(strict=True, invalid1=123, invalid2=456)

def test_transform_strict_with_valid_params():
# Test that strict mode doesn't affect valid parameters
transform = A.Blur(strict=True, p=0.7, blur_limit=(3, 5))
assert transform.p == 0.7
assert transform.blur_limit == (3, 5)
12 changes: 0 additions & 12 deletions tests/test_pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,18 +335,6 @@ def test_custom_image_transform_signature() -> None:
assert expected_params["custom_param"].annotation is int


def test_wrong_argument() -> None:
"""Test that pas Transform will get warning"""
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
transform = A.Blur(wrong_param=10)
assert not hasattr(transform, "wrong_param")
assert len(w) == 1
assert issubclass(w[0].category, UserWarning)
assert str(w[0].message) == "Argument 'wrong_param' is not valid and will be ignored."
warnings.resetwarnings()


def test_check_range_bounds_doctest():
# Test the examples from the docstring
validator = check_range_bounds(0, 1)
Expand Down
15 changes: 15 additions & 0 deletions tests/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -778,3 +778,18 @@ def get_all_init_schema_fields(model_cls: A.BasicTransform) -> Set[str]:
assert reported_args.issubset(
expected_args
), f"Mismatch in {augmentation_cls.__name__}: Serialized fields {reported_args} not a subset of schema fields {expected_args}"



def test_serialization_excludes_strict() -> None:
# Test that strict parameter is not included in serialization
transform = A.Compose([A.HorizontalFlip()])
transform_dict = A.to_dict(transform)["transform"]
assert "strict" not in transform_dict
# Also check nested transforms
assert "strict" not in transform_dict["transforms"][0]

# Test individual transform serialization
transform = A.HorizontalFlip(strict=True)
transform_dict = A.to_dict(transform)["transform"]
assert "strict" not in transform_dict
14 changes: 8 additions & 6 deletions tests/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,9 @@ def test_rotate_crop_border():
)
def test_binary_mask_interpolation(augmentation_cls, params):
"""Checks whether transformations based on DualTransform does not introduce a mask interpolation artifacts"""
params["fill_mask"] = 0
params["mask_interpolation"] = cv2.INTER_NEAREST
params["fill_mask"] = 0

aug = augmentation_cls(p=1, **params)
image = SQUARE_UINT8_IMAGE
mask = np.random.randint(low=0, high=2, size=(100, 100), dtype=np.uint8)
Expand All @@ -97,7 +98,7 @@ def test_binary_mask_interpolation(augmentation_cls, params):
["augmentation_cls", "params"],
get_dual_transforms(
custom_arguments={
A.GridDropout: {"num_grid_xy": (10, 10), "fill_mask": 64},
A.GridDropout: {"holes_number_xy": (10, 10), "fill_mask": 64},
A.TemplateTransform: {
"templates": np.random.randint(
low=0, high=256, size=(100, 100, 3), dtype=np.uint8
Expand Down Expand Up @@ -139,13 +140,14 @@ def test_binary_mask_interpolation(augmentation_cls, params):
def test_semantic_mask_interpolation(augmentation_cls, params, image):
"""Checks whether transformations based on DualTransform does not introduce a mask interpolation artifacts."""

np.random.seed(42)
mask = np.random.randint(low=0, high=4, size=(100, 100), dtype=np.uint8) * 64

seed = 137
params["mask_interpolation"] = cv2.INTER_NEAREST
params["fill_mask"] = 0

data = A.Compose([augmentation_cls(p=1, **params)], seed=42)(image=image, mask=mask)
np.random.seed(seed)
mask = np.random.randint(low=0, high=4, size=(100, 100), dtype=np.uint8) * 64

data = A.Compose([augmentation_cls(p=1, **params)], seed=seed)(image=image, mask=mask)

np.testing.assert_array_equal(np.unique(data["mask"]), np.array([0, 64, 128, 192]))

Expand Down
2 changes: 1 addition & 1 deletion tests/transforms3d/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -561,7 +561,7 @@ def test_image_volume_matching(image, augmentation_cls, params):
)
def test_keypoints_xy_xyz(augmentation_cls, params):
"""Test that xy and xyz keypoint formats produce identical results for x,y coordinates."""
seed = 42
seed = 137
aug1 = A.Compose([augmentation_cls(**params, p=1)], seed=seed, keypoint_params={"format": "xy"})
aug2 = A.Compose([augmentation_cls(**params, p=1)], seed=seed, keypoint_params={"format": "xyz"})

Expand Down

0 comments on commit 7d95c93

Please sign in to comment.