From fcda65b5a0a2e695c75529e70223d27d88ad5b9e Mon Sep 17 00:00:00 2001
From: Vladimir Iglovikov
Date: Wed, 29 Jan 2025 09:07:12 -0800
Subject: [PATCH] Update autocontrast (#2317)

* Added PIL method to autocontrast

* Speed up in AutoContrast

* Speed up in AutoContrast

* Treat `ignore` as a pixel value only, never as a channel index

* Pass an 8-bit mask and an exclusive upper bound of 256 to cv2.calcHist

* Guard the CDF lookup table against a zero-width CDF
---
 albumentations/augmentations/functional.py | 165 ++++++++++--
 albumentations/augmentations/transforms.py |  75 +++++-
 tests/aug_definitions.py                   |   5 +-
 tests/functional/test_functional.py        | 277 ++++++++++++++++++++-
 4 files changed, 484 insertions(+), 38 deletions(-)

diff --git a/albumentations/augmentations/functional.py b/albumentations/augmentations/functional.py
index 726663c7d..b3bed9e93 100644
--- a/albumentations/augmentations/functional.py
+++ b/albumentations/augmentations/functional.py
@@ -2685,55 +2685,57 @@ def apply_gaussian_illumination(
 
 
 @uint8_io
-def auto_contrast(img: np.ndarray) -> np.ndarray:
+def auto_contrast(
+    img: np.ndarray,
+    cutoff: float,
+    ignore: int | None,
+    method: Literal["cdf", "pil"],
+) -> np.ndarray:
     """Apply auto contrast to the image.
 
     Auto contrast enhances image contrast by stretching the intensity range
     to use the full range while preserving relative intensities.
 
     Args:
         img: Input image in uint8 or float32 format.
+        cutoff: Percentage of pixels to cut off from each end of the histogram.
+                Range: 0-100. 0 means no cutoff.
+        ignore: Pixel value (not a channel index) to leave out of the histogram
+                and to keep unchanged in the output, or None.
+        method: Method to use for contrast enhancement:
+                - "cdf": Uses cumulative distribution function (original albumentations method)
+                - "pil": Uses linear scaling like PIL.ImageOps.autocontrast
 
     Returns:
         Contrast-enhanced image in the same dtype as input.
 
     Note:
-        The function:
-        1. Computes histogram for each channel
-        2. Creates cumulative distribution
-        3. Normalizes to full intensity range
-        4. Uses lookup table for scaling
+        Channels whose histogram bounds collapse to a single intensity
+        (constant, or empty after masking) are left unchanged.
     """
     result = img.copy()
     num_channels = get_num_channels(img)
     max_value = MAX_VALUES_BY_DTYPE[img.dtype]
+    channels = cv2.split(img) if img.ndim > MONO_CHANNEL_DIMENSIONS else [img]
 
     for i in range(num_channels):
-        channel = img[..., i] if img.ndim > MONO_CHANNEL_DIMENSIONS else img
+        channel = channels[i]
 
-        # Compute histogram
-        hist = np.histogram(channel.flatten(), bins=256, range=(0, max_value))[0]
-
-        # Calculate cumulative distribution
-        cdf = hist.cumsum()
-
-        # Find the minimum and maximum non-zero values in the CDF
-        if cdf[cdf > 0].size == 0:
-            continue  # Skip if the channel is constant or empty
-
-        cdf_min = cdf[cdf > 0].min()
-        cdf_max = cdf.max()
+        # cv2.calcHist expects an 8-bit mask, so convert the boolean comparison.
+        mask = None if ignore is None else (channel != ignore).astype(np.uint8)
+        # The upper range bound is exclusive: use 256 so that pixels equal to
+        # 255 are counted in the last bin.
+        hist = cv2.calcHist([channel], [0], mask, [256], [0, 256]).ravel()
 
-        if cdf_min == cdf_max:
+        lo, hi = get_histogram_bounds(hist, cutoff)
+        if hi <= lo:
             continue
 
-        # Normalize CDF
-        cdf = (cdf - cdf_min) * max_value / (cdf_max - cdf_min)
+        lut = create_contrast_lut(hist, lo, hi, max_value, method)
+        if ignore is not None:
+            # Keep the ignored value fixed instead of stretching it.
+            lut[ignore] = ignore
 
-        # Create lookup table
-        lut = np.clip(np.around(cdf), 0, max_value).astype(np.uint8)
-
-        # Apply lookup table
         if img.ndim > MONO_CHANNEL_DIMENSIONS:
             result[..., i] = sz_lut(channel, lut)
         else:
@@ -2742,6 +2744,117 @@ def auto_contrast(img: np.ndarray) -> np.ndarray:
     return result
 
 
+def create_contrast_lut(
+    hist: np.ndarray,
+    min_intensity: int,
+    max_intensity: int,
+    max_value: int,
+    method: Literal["cdf", "pil"],
+) -> np.ndarray:
+    """Create a 256-entry uint8 lookup table for contrast adjustment.
+
+    Args:
+        hist: Histogram of the channel, indexed by intensity.
+        min_intensity: Lowest intensity of the range to stretch.
+        max_intensity: Highest intensity of the range to stretch.
+        max_value: Output value assigned to the top of the range and above.
+        method: "cdf" follows the cumulative distribution inside the range,
+            "pil" scales linearly.
+
+    Returns:
+        Lookup table of shape (256,). All zeros when the range is empty or a
+        single intensity; identity for "cdf" when the range has no spread.
+    """
+    # Handle single intensity case
+    if min_intensity >= max_intensity:
+        return np.zeros(256, dtype=np.uint8)
+
+    if method == "cdf":
+        hist_range = hist[min_intensity : max_intensity + 1]
+        cdf = hist_range.cumsum()
+
+        # No valid pixels, or all of them in the first bin of the range:
+        # return the identity LUT instead of dividing by zero below.
+        cdf_span = cdf[-1] - cdf[0]
+        if cdf_span == 0:
+            return np.arange(256, dtype=np.uint8)
+
+        # Normalize CDF to full range
+        cdf = (cdf - cdf[0]) * max_value / cdf_span
+
+        # Create lookup table
+        lut = np.zeros(256, dtype=np.uint8)
+        lut[min_intensity : max_intensity + 1] = np.clip(np.round(cdf), 0, max_value).astype(np.uint8)
+        lut[max_intensity + 1 :] = max_value
+        return lut
+
+    # "pil" method
+    scale = max_value / (max_intensity - min_intensity)
+    indices = np.arange(256, dtype=float)
+    # Round (rather than truncate) so the midpoint of range [0, 2] maps to 128.
+    lut = np.clip(np.round((indices - min_intensity) * scale), 0, max_value).astype(np.uint8)
+    lut[:min_intensity] = 0
+    lut[max_intensity + 1 :] = max_value
+    return lut
+
+
+def get_histogram_bounds(hist: np.ndarray, cutoff: float) -> tuple[int, int]:
+    """Find the low and high intensity bounds of the histogram.
+
+    Args:
+        hist: Histogram indexed by intensity.
+        cutoff: Percentage of the total pixel count to trim from each end.
+            With 0 the bounds are the first and last non-empty bins.
+
+    Returns:
+        (min_intensity, max_intensity). (0, 0) for an empty histogram; the
+        middle bin twice when the trimmed bounds cross.
+    """
+    if not cutoff:
+        non_zero_intensities = np.nonzero(hist)[0]
+        if len(non_zero_intensities) == 0:
+            return 0, 0
+        return int(non_zero_intensities[0]), int(non_zero_intensities[-1])
+
+    total_pixels = float(hist.sum())
+    if total_pixels == 0:
+        return 0, 0
+
+    pixels_to_cut = total_pixels * cutoff / 100.0
+
+    # Special case for uniform 256-bin histogram
+    if len(hist) == 256 and np.all(hist == hist[0]):
+        min_intensity = int(len(hist) * cutoff / 100)  # truncated to an int
+        max_intensity = len(hist) - min_intensity - 1
+        return min_intensity, max_intensity
+
+    # Find minimum intensity
+    cumsum = 0.0
+    min_intensity = 0
+    for i in range(len(hist)):
+        cumsum += hist[i]
+        if cumsum >= pixels_to_cut:  # Use >= for left bound
+            min_intensity = i + 1
+            break
+    min_intensity = min(min_intensity, len(hist) - 1)
+
+    # Find maximum intensity
+    cumsum = 0.0
+    max_intensity = len(hist) - 1
+    for i in range(len(hist) - 1, -1, -1):
+        cumsum += hist[i]
+        if cumsum >= pixels_to_cut:  # Use >= for right bound
+            max_intensity = i
+            break
+
+    # Handle edge cases
+    if min_intensity > max_intensity:
+        mid_point = (len(hist) - 1) // 2
+        return mid_point, mid_point
+
+    return min_intensity, max_intensity
+
+
 def get_drop_mask(
     shape: tuple[int, ...],
     per_channel: bool,
diff --git a/albumentations/augmentations/transforms.py b/albumentations/augmentations/transforms.py
index a69437dfd..8095de7c7 100644
--- a/albumentations/augmentations/transforms.py
+++ b/albumentations/augmentations/transforms.py
@@ -6397,33 +6397,88 @@ def get_transform_init_args_names(self) -> tuple[str, ...]:
 
 
 class AutoContrast(ImageOnlyTransform):
-    """Apply random auto contrast to images.
+    """Automatically adjust image contrast by stretching the intensity range.
 
-    Auto contrast enhances image contrast by stretching the intensity range
-    to use the full range while preserving relative intensities. For each
-    color channel:
-    1. Compute histogram
-    2. Find cumulative percentiles
-    3. Clip and scale intensities to full range
+    This transform provides two methods for contrast enhancement:
+    1. CDF method (default): Uses cumulative distribution function for more gradual adjustment
+    2. PIL method: Uses linear scaling like PIL.ImageOps.autocontrast
+
+    The transform can optionally exclude extreme values from both ends of the
+    intensity range and preserve specific intensity values (e.g., alpha channel).
 
     Args:
-        p (float): probability of applying the transform. Default: 0.5.
+        cutoff (float): Percentage of pixels to exclude from both ends of the histogram.
+            Range: [0, 100]. Default: 0 (use full intensity range)
+            - 0 means use the minimum and maximum intensity values found
+            - 20 means exclude darkest and brightest 20% of pixels
+        ignore (int, optional): Intensity value to preserve (e.g., alpha channel).
+            Range: [0, 255]. Default: None
+            - If specified, this intensity value will not be modified
+            - Useful for images with alpha channel or special marker values
+        method (Literal["cdf", "pil"]): Algorithm to use for contrast enhancement.
+            Default: "cdf"
+            - "cdf": Uses cumulative distribution for smoother adjustment
+            - "pil": Uses linear scaling like PIL.ImageOps.autocontrast
+        p (float): Probability of applying the transform. Default: 0.5
 
     Targets:
         image
 
     Image types:
         uint8, float32
+
+    Note:
+        - The transform processes each color channel independently
+        - For grayscale images, only one channel is processed
+        - The output maintains the same dtype as input
+        - Empty or single-color channels remain unchanged
+
+    Examples:
+        >>> import albumentations as A
+        >>> # Basic usage
+        >>> transform = A.AutoContrast(p=1.0)
+        >>>
+        >>> # Exclude extreme values
+        >>> transform = A.AutoContrast(cutoff=20, p=1.0)
+        >>>
+        >>> # Preserve alpha channel
+        >>> transform = A.AutoContrast(ignore=255, p=1.0)
+        >>>
+        >>> # Use PIL-like contrast enhancement
+        >>> transform = A.AutoContrast(method="pil", p=1.0)
     """
 
+    class InitSchema(BaseTransformInitSchema):
+        cutoff: float = Field(ge=0, le=100)
+        ignore: int | None = Field(ge=0, le=255)
+        method: Literal["cdf", "pil"]
+
     def __init__(
         self,
+        cutoff: float = 0,
+        ignore: int | None = None,
+        method: Literal["cdf", "pil"] = "cdf",
         p: float = 0.5,
     ):
         super().__init__(p=p)
+        self.cutoff = cutoff
+        self.ignore = ignore
+        self.method = method
 
     def apply(self, img: np.ndarray, **params: Any) -> np.ndarray:
-        return fmain.auto_contrast(img)
+        return fmain.auto_contrast(img, self.cutoff, self.ignore, self.method)
+
+    @batch_transform("channel", has_batch_dim=True, has_depth_dim=False)
+    def apply_to_images(self, images: np.ndarray, **params: Any) -> np.ndarray:
+        return self.apply(images, **params)
+
+    @batch_transform("channel", has_batch_dim=False, has_depth_dim=True)
+    def apply_to_volume(self, volume: np.ndarray, **params: Any) -> np.ndarray:
+        return self.apply(volume, **params)
+
+    @batch_transform("channel", has_batch_dim=True, has_depth_dim=True)
+    def apply_to_volumes(self, volumes: np.ndarray, **params: Any) -> np.ndarray:
+        return self.apply(volumes, **params)
 
     def get_transform_init_args_names(self) -> tuple[str, ...]:
-        return ()
+        return "cutoff", "ignore", "method"
diff --git a/tests/aug_definitions.py b/tests/aug_definitions.py
index bdb983f7b..65218d723 100644
--- a/tests/aug_definitions.py
+++ b/tests/aug_definitions.py
@@ -404,7 +404,10 @@
     [A.PlasmaShadow, {}],
     [A.Illumination, {}],
     [A.ThinPlateSpline, {}],
-    [A.AutoContrast, {}],
+    [A.AutoContrast, [
+        {"cutoff": 0, "ignore": None, "method": "cdf"},
+        {"cutoff": 0, "ignore": None, "method": "pil"},
+    ]],
     [A.PadIfNeeded3D, {"min_zyx": (300, 200, 400), "pad_divisor_zyx": (10, 10, 10), "position": "center", "fill": 10, "fill_mask": 20}],
     [A.Pad3D, {"padding": 10}],
     [A.CenterCrop3D, {"size": (2, 30, 30)}],
diff --git a/tests/functional/test_functional.py b/tests/functional/test_functional.py
index 96ecb30f4..a7222c2c7 100644
--- a/tests/functional/test_functional.py
+++ b/tests/functional/test_functional.py
@@ -1211,7 +1211,7 @@ def test_image_compression_quality_with_patterns(image_type):
     ],
 )
 def test_auto_contrast(img, expected):
-    result = fmain.auto_contrast(img)
+    result = fmain.auto_contrast(img, cutoff=0, ignore=None, method="cdf")
 
     if expected == "constant":
         (
@@ -1608,3 +1608,278 @@ def test_plasma_pattern_statistical_properties():
     # Test distribution is roughly symmetric
     median = np.median(pattern)
     assert 0.3 <= median <= 0.7  # Wider bounds to account for randomness
+
+
+@pytest.mark.parametrize(
+    ["hist", "min_intensity", "max_intensity", "max_value", "method", "expected_output"],
+    [
+        # Test PIL method with simple range
+        (
+            np.array([1, 1, 1]),  # Simple histogram
+            0, 2,  # min/max intensities
+            255,  # max value
+            "pil",
+            np.array([0, 128, 255]),  # Expected LUT for first 3 values
+        ),
+
+        # Test CDF method with simple range
+        (
+            np.array([1, 1, 1]),  # Equal distribution
+            0, 2,  # min/max intensities
+            255,  # max value
+            "cdf",
+            np.array([0, 128, 255]),  # Expected LUT for first 3 values
+        ),
+
+        # Test empty histogram with PIL method
+        (
+            np.zeros(256),  # Empty histogram
+            0, 255,
+            255,
+            "pil",
+            np.arange(256, dtype=np.uint8),  # Should return identity LUT
+        ),
+
+        # Test empty histogram with CDF method
+        (
+            np.zeros(256),
+            0, 255,
+            255,
+            "cdf",
+            np.arange(256, dtype=np.uint8),  # Should return identity LUT
+        ),
+
+        # Test single value histogram with PIL method
+        (
+            np.array([0, 10, 0]),  # Single non-zero value
+            1, 1,
+            255,
+            "pil",
+            np.zeros(256, dtype=np.uint8),  # Should map everything to 0
+        ),
+
+        # Test narrow range with PIL method
+        (
+            np.array([0, 1, 1, 1, 0]),
+            1, 3,
+            255,
+            "pil",
+            np.array([0, 0, 128, 255, 255, *[255]*(256-5)]),  # Linear scaling
+        ),
+
+        # Test CDF method when all pixels fall in the first bin of the range
+        (
+            np.array([10, 0, 0]),
+            0, 2,
+            255,
+            "cdf",
+            np.arange(256, dtype=np.uint8),  # Should return identity LUT
+        ),
+    ]
+)
+def test_create_contrast_lut(
+    hist: np.ndarray,
+    min_intensity: int,
+    max_intensity: int,
+    max_value: int,
+    method: str,
+    expected_output: np.ndarray
+):
+    """Test create_contrast_lut function with various inputs."""
+    # If hist is smaller than 256, pad it
+    if len(hist) < 256:
+        hist = np.pad(hist, (0, 256 - len(hist)))
+
+    # Generate LUT
+    lut = fmain.create_contrast_lut(
+        hist=hist,
+        min_intensity=min_intensity,
+        max_intensity=max_intensity,
+        max_value=max_value,
+        method=method
+    )
+
+    # Basic checks
+    assert isinstance(lut, np.ndarray)
+    assert lut.dtype == np.uint8
+    assert lut.shape == (256,)
+    assert np.all(lut >= 0)
+    assert np.all(lut <= max_value)
+
+    # Check if first few values match expected
+    assert np.array_equal(
+        lut[:len(expected_output)],
+        expected_output[:len(expected_output)]
+    )
+
+
+def test_create_contrast_lut_properties():
+    """Test mathematical properties of the lookup tables."""
+    hist = np.random.randint(0, 100, 256)
+    max_value = 255
+
+    # Test monotonicity for PIL method
+    lut_pil = fmain.create_contrast_lut(
+        hist=hist,
+        min_intensity=50,
+        max_intensity=200,
+        max_value=max_value,
+        method="pil"
+    )
+    assert np.all(np.diff(lut_pil) >= 0), "PIL LUT should be monotonically increasing"
+
+    # Test CDF method preserves relative frequencies
+    lut_cdf = fmain.create_contrast_lut(
+        hist=hist,
+        min_intensity=50,
+        max_intensity=200,
+        max_value=max_value,
+        method="cdf"
+    )
+    assert np.all(np.diff(lut_cdf) >= 0), "CDF LUT should be monotonically increasing"
+
+
+
+@pytest.mark.parametrize(
+    ["hist", "cutoff", "expected"],
+    [
+        # Test with no cutoff
+        (
+            np.array([0, 1, 1, 1, 0]),  # Simple histogram
+            0,
+            (1, 3)  # Should return first and last non-zero indices
+        ),
+
+        # Test with empty histogram
+        (
+            np.zeros(256),
+            0,
+            (0, 0)  # Should return (0, 0) for empty histogram
+        ),
+
+        # Test with single value histogram
+        (
+            np.array([0, 10, 0, 0]),
+            0,
+            (1, 1)  # Should return same index for single peak
+        ),
+
+        # Test with 20% cutoff
+        (
+            np.array([10, 10, 10, 10, 10]),  # Uniform histogram
+            20,
+            (1, 4)  # Should cut 20% from each end
+        ),
+
+        # Test with 50% cutoff
+        (
+            np.array([10, 10, 10, 10, 10]),
+            50,
+            (2, 2)  # Should converge to middle
+        ),
+
+        # Test with asymmetric histogram
+        (
+            np.array([50, 10, 10, 10, 20]),  # More weight on edges
+            20,
+            (1, 4)  # Should adjust for weight distribution
+        ),
+
+        # Test with all pixels in one bin
+        (
+            np.array([0, 100, 0, 0]),
+            10,
+            (1, 1)  # Should return the peak location
+        ),
+    ]
+)
+def test_get_histogram_bounds(hist: np.ndarray, cutoff: float, expected: tuple[int, int]):
+    """Test get_histogram_bounds with various histogram shapes and cutoffs."""
+    min_intensity, max_intensity = fmain.get_histogram_bounds(hist, cutoff)
+
+    assert isinstance(min_intensity, int)
+    assert isinstance(max_intensity, int)
+    assert min_intensity <= max_intensity
+    assert min_intensity >= 0
+    assert max_intensity < len(hist)
+    assert (min_intensity, max_intensity) == expected
+
+
+
+def test_get_histogram_bounds_properties():
+    """Test mathematical and logical properties of the bounds."""
+    np.random.seed(42)  # For reproducibility
+    hist = np.random.randint(0, 100, 256)
+
+    cutoffs = [0, 10, 25, 49]
+    previous_range = 256
+
+    for cutoff in cutoffs:
+        min_intensity, max_intensity = fmain.get_histogram_bounds(hist, cutoff)
+
+        # Range should decrease as cutoff increases
+        current_range = max_intensity - min_intensity + 1
+        assert current_range <= previous_range, \
+            f"Range should decrease with increasing cutoff. Cutoff: {cutoff}"
+        previous_range = current_range
+
+        # Verify percentage of pixels included
+        if cutoff > 0:
+            pixels_before_min = hist[:min_intensity].sum()
+            total_pixels = hist.sum()
+
+            expected_cut = total_pixels * cutoff / 100
+            relative_error = abs(pixels_before_min - expected_cut) / expected_cut
+            assert relative_error <= 0.1, \
+                f"Lower bound cut incorrect for cutoff {cutoff}"
+
+
+def test_get_histogram_bounds_edge_cases():
+    """Test edge cases for get_histogram_bounds."""
+    # Test with all zeros except edges
+    hist = np.zeros(256)
+    hist[0] = hist[-1] = 100
+    min_intensity, max_intensity = fmain.get_histogram_bounds(hist, 0)
+    assert (min_intensity, max_intensity) == (0, 255)
+
+    # Test with single non-zero value
+    hist = np.zeros(256)
+    hist[128] = 100
+    min_intensity, max_intensity = fmain.get_histogram_bounds(hist, 0)
+    assert min_intensity == max_intensity == 128
+
+    # Test with constant histogram and 25% cutoff
+    hist = np.ones(256)
+    min_intensity, max_intensity = fmain.get_histogram_bounds(hist, 25)
+    # With uniform distribution, should cut 25% from each end
+    assert min_intensity == 64  # 256 * 0.25
+    assert max_intensity == 191  # 256 * 0.75 - 1
+
+
+def test_get_histogram_bounds_numerical_stability():
+    """Test numerical stability with very large and small values."""
+    # Test with very large values
+    hist = np.ones(256) * 1e6
+    min_intensity, max_intensity = fmain.get_histogram_bounds(hist, 10)
+    # With uniform distribution, should cut 10% from each end
+    assert min_intensity == 25  # 256 * 0.10
+    assert max_intensity == 230  # 256 * 0.90 - 1
+
+
+def test_auto_contrast_counts_pixels_at_max_value():
+    """Pixels equal to 255 must be part of the histogram."""
+    img = np.array([[100, 255]], dtype=np.uint8)
+    result = fmain.auto_contrast(img, cutoff=0, ignore=None, method="pil")
+    np.testing.assert_array_equal(result, np.array([[0, 255]], dtype=np.uint8))
+
+
+def test_auto_contrast_ignore_is_pixel_value_not_channel_index():
+    """ignore=1 must not skip channel 1."""
+    img = np.empty((1, 2, 3), dtype=np.uint8)
+    img[0, 0] = 100
+    img[0, 1] = 200
+    result = fmain.auto_contrast(img, cutoff=0, ignore=1, method="pil")
+    expected = np.empty_like(img)
+    expected[0, 0] = 0
+    expected[0, 1] = 255
+    np.testing.assert_array_equal(result, expected)