Skip to content

Commit

Permalink
Update autocontrast (#2317)
Browse files Browse the repository at this point in the history
* Added PIL method to autocontrast

* Speed up in AutoContrast

* Speed up in AutoContrast
  • Loading branch information
ternaus authored Jan 29, 2025
1 parent 3e6945d commit fcda65b
Show file tree
Hide file tree
Showing 4 changed files with 442 additions and 43 deletions.
156 changes: 125 additions & 31 deletions albumentations/augmentations/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -2685,55 +2685,63 @@ def apply_gaussian_illumination(


@uint8_io
def auto_contrast(img: np.ndarray) -> np.ndarray:
def auto_contrast(
img: np.ndarray,
cutoff: float,
ignore: int | None,
method: Literal["cdf", "pil"],
) -> np.ndarray:
"""Apply auto contrast to the image.
Auto contrast enhances image contrast by stretching the intensity range
to use the full range while preserving relative intensities.
Args:
img: Input image in uint8 or float32 format.
cutoff: Percentage of pixels to cut off from the histogram edges.
Range: 0-100. Default: 0 (no cutoff)
ignore: Pixel value to ignore in auto contrast calculation.
Useful for handling alpha channels or other special values.
method: Method to use for contrast enhancement:
- "cdf": Uses cumulative distribution function (original albumentations method)
- "pil": Uses linear scaling like PIL.ImageOps.autocontrast
Returns:
Contrast-enhanced image in the same dtype as input.
Note:
The function:
1. Computes histogram for each channel
2. Creates cumulative distribution
3. Normalizes to full intensity range
4. Uses lookup table for scaling
"""
result = img.copy()
num_channels = get_num_channels(img)
max_value = MAX_VALUES_BY_DTYPE[img.dtype]

for i in range(num_channels):
channel = img[..., i] if img.ndim > MONO_CHANNEL_DIMENSIONS else img

# Compute histogram
hist = np.histogram(channel.flatten(), bins=256, range=(0, max_value))[0]

# Calculate cumulative distribution
cdf = hist.cumsum()
# Pre-compute histograms using cv2.calcHist - much faster than np.histogram
if img.ndim > MONO_CHANNEL_DIMENSIONS:
channels = cv2.split(img)
hists: list[np.ndarray] = []
for i, channel in enumerate(channels):
if ignore is not None and i == ignore:
hists.append(None)
continue
mask = None if ignore is None else (channel != ignore)
hist = cv2.calcHist([channel], [0], mask, [256], [0, max_value])
hists.append(hist.ravel())

# Find the minimum and maximum non-zero values in the CDF
if cdf[cdf > 0].size == 0:
continue # Skip if the channel is constant or empty
for i in range(num_channels):
if ignore is not None and i == ignore:
continue

cdf_min = cdf[cdf > 0].min()
cdf_max = cdf.max()
if img.ndim > MONO_CHANNEL_DIMENSIONS:
hist = hists[i]
channel = channels[i]
else:
mask = None if ignore is None else (img != ignore)
hist = cv2.calcHist([img], [0], mask, [256], [0, max_value]).ravel()
channel = img

if cdf_min == cdf_max:
lo, hi = get_histogram_bounds(hist, cutoff)
if hi <= lo:
continue

# Normalize CDF
cdf = (cdf - cdf_min) * max_value / (cdf_max - cdf_min)
lut = create_contrast_lut(hist, lo, hi, max_value, method)
if ignore is not None:
lut[ignore] = ignore

# Create lookup table
lut = np.clip(np.around(cdf), 0, max_value).astype(np.uint8)

# Apply lookup table
if img.ndim > MONO_CHANNEL_DIMENSIONS:
result[..., i] = sz_lut(channel, lut)
else:
Expand All @@ -2742,6 +2750,92 @@ def auto_contrast(img: np.ndarray) -> np.ndarray:
return result


def create_contrast_lut(
    hist: np.ndarray,
    min_intensity: int,
    max_intensity: int,
    max_value: int,
    method: Literal["cdf", "pil"],
) -> np.ndarray:
    """Create a 256-entry uint8 lookup table for contrast adjustment.

    Args:
        hist: 1-D histogram of pixel counts, indexed by intensity.
        min_intensity: Lowest intensity of the range to stretch.
        max_intensity: Highest intensity of the range to stretch.
        max_value: Output value that ``max_intensity`` (and everything above it)
            is mapped to.
        method: ``"cdf"`` maps intensities through the normalised cumulative
            histogram of ``[min_intensity, max_intensity]``; any other value
            takes the linear ("pil") path.

    Returns:
        Array of shape ``(256,)`` and dtype ``uint8``.

        - All zeros when ``min_intensity >= max_intensity``.
        - The identity table when ``method == "cdf"`` and the cumulative
          histogram over the range is flat (no counted pixels, or every counted
          pixel sits in the first bin), since there is nothing to normalise by.
    """
    # Degenerate range: a single intensity (or inverted bounds) cannot be stretched.
    if min_intensity >= max_intensity:
        return np.zeros(256, dtype=np.uint8)

    if method == "cdf":
        cdf = hist[min_intensity : max_intensity + 1].cumsum()

        # A flat CDF makes the normalisation divisor zero. Previously only the
        # "no pixels at all" case was guarded; a range whose pixels all fall in
        # the first bin divided 0 by 0 and produced a LUT built from NaN.
        if cdf[-1] == 0 or cdf[-1] == cdf[0]:
            return np.arange(256, dtype=np.uint8)

        # Normalise the CDF so the range spans [0, max_value].
        cdf = (cdf - cdf[0]) * max_value / (cdf[-1] - cdf[0])

        # Entries below min_intensity stay 0; entries above max_intensity saturate.
        lut = np.zeros(256, dtype=np.uint8)
        lut[min_intensity : max_intensity + 1] = np.clip(np.round(cdf), 0, max_value).astype(np.uint8)
        lut[max_intensity + 1 :] = max_value
        return lut

    # "pil" method: straight line through (min_intensity, 0) and (max_intensity, max_value).
    scale = max_value / (max_intensity - min_intensity)
    indices = np.arange(256, dtype=float)
    # np.round (rather than truncation) maps the middle of the range [0, 2] to 128.
    lut = np.clip(np.round((indices - min_intensity) * scale), 0, max_value).astype(np.uint8)
    lut[:min_intensity] = 0
    lut[max_intensity + 1 :] = max_value
    return lut


def get_histogram_bounds(hist: np.ndarray, cutoff: float) -> tuple[int, int]:
    """Find the low and high intensity bounds of a histogram.

    Args:
        hist: Histogram of pixel counts, indexed by intensity. It is flattened
            before use.
        cutoff: Percentage of the total pixel count to discard from each end
            of the histogram. A falsy value (``0``) disables the cutoff.

    Returns:
        ``(low, high)`` bin indices as Python ints.

        - Without a cutoff: the first and last non-empty bins, or ``(0, 0)``
          for an all-zero histogram.
        - With a cutoff: ``low`` is one past the first bin at which the running
          sum from the left reaches the cut, and ``high`` is the first bin at
          which the running sum from the right reaches it. ``(0, 0)`` is
          returned for an all-zero histogram.
        - If the bounds cross (``low > high``), both collapse to the middle bin
          ``(len(hist) - 1) // 2``.
    """
    hist = np.asarray(hist).ravel()
    n_bins = len(hist)

    if not cutoff:
        occupied = np.nonzero(hist)[0]
        if len(occupied) == 0:
            return 0, 0
        return int(occupied[0]), int(occupied[-1])

    total_pixels = float(hist.sum())
    if total_pixels == 0:
        return 0, 0

    pixels_to_cut = total_pixels * cutoff / 100.0

    if n_bins == 256 and np.all(hist == hist[0]):
        # Uniform 256-bin histogram: cut the same whole number of bins from each side.
        min_intensity = int(n_bins * cutoff / 100)
        max_intensity = n_bins - min_intensity - 1
    else:
        # Low bound: one past the first bin where the left-to-right running sum
        # reaches the cut; stays 0 if it is never reached.
        from_left = np.cumsum(hist, dtype=np.float64)
        reached = np.flatnonzero(from_left >= pixels_to_cut)
        min_intensity = int(reached[0]) + 1 if reached.size else 0
        min_intensity = min(min_intensity, n_bins - 1)

        # High bound: first bin (scanning from the top) where the right-to-left
        # running sum reaches the cut; stays at the last bin if never reached.
        from_right = np.cumsum(hist[::-1], dtype=np.float64)
        reached = np.flatnonzero(from_right >= pixels_to_cut)
        max_intensity = n_bins - 1 - int(reached[0]) if reached.size else n_bins - 1

    # Crossed bounds collapse to the middle bin. The uniform shortcut goes
    # through this check too; previously it returned early and could yield
    # inverted or out-of-range bounds such as (256, -1) for cutoff=100.
    if min_intensity > max_intensity:
        mid_point = (n_bins - 1) // 2
        return mid_point, mid_point

    return min_intensity, max_intensity


def get_drop_mask(
shape: tuple[int, ...],
per_channel: bool,
Expand Down
75 changes: 65 additions & 10 deletions albumentations/augmentations/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -6397,33 +6397,88 @@ def get_transform_init_args_names(self) -> tuple[str, ...]:


class AutoContrast(ImageOnlyTransform):
"""Apply random auto contrast to images.
"""Automatically adjust image contrast by stretching the intensity range.
Auto contrast enhances image contrast by stretching the intensity range
to use the full range while preserving relative intensities. For each
color channel:
1. Compute histogram
2. Find cumulative percentiles
3. Clip and scale intensities to full range
This transform provides two methods for contrast enhancement:
1. CDF method (default): Uses cumulative distribution function for more gradual adjustment
2. PIL method: Uses linear scaling like PIL.ImageOps.autocontrast
The transform can optionally exclude extreme values from both ends of the
intensity range and preserve specific intensity values (e.g., alpha channel).
Args:
p (float): probability of applying the transform. Default: 0.5.
cutoff (float): Percentage of pixels to exclude from both ends of the histogram.
Range: [0, 100]. Default: 0 (use full intensity range)
- 0 means use the minimum and maximum intensity values found
- 20 means exclude darkest and brightest 20% of pixels
ignore (int, optional): Intensity value to preserve (e.g., alpha channel).
Range: [0, 255]. Default: None
- If specified, this intensity value will not be modified
- Useful for images with alpha channel or special marker values
method (Literal["cdf", "pil"]): Algorithm to use for contrast enhancement.
Default: "cdf"
- "cdf": Uses cumulative distribution for smoother adjustment
- "pil": Uses linear scaling like PIL.ImageOps.autocontrast
p (float): Probability of applying the transform. Default: 0.5
Targets:
image
Image types:
uint8, float32
Note:
- The transform processes each color channel independently
- For grayscale images, only one channel is processed
- The output maintains the same dtype as input
- Empty or single-color channels remain unchanged
Examples:
>>> import albumentations as A
>>> # Basic usage
>>> transform = A.AutoContrast(p=1.0)
>>>
>>> # Exclude extreme values
>>> transform = A.AutoContrast(cutoff=20, p=1.0)
>>>
>>> # Preserve alpha channel
>>> transform = A.AutoContrast(ignore=255, p=1.0)
>>>
>>> # Use PIL-like contrast enhancement
>>> transform = A.AutoContrast(method="pil", p=1.0)
"""

class InitSchema(BaseTransformInitSchema):
cutoff: float = Field(ge=0, le=100)
ignore: int | None = Field(ge=0, le=255)
method: Literal["cdf", "pil"]

def __init__(
self,
cutoff: float = 0,
ignore: int | None = None,
method: Literal["cdf", "pil"] = "cdf",
p: float = 0.5,
):
super().__init__(p=p)
self.cutoff = cutoff
self.ignore = ignore
self.method = method

def apply(self, img: np.ndarray, **params: Any) -> np.ndarray:
return fmain.auto_contrast(img)
return fmain.auto_contrast(img, self.cutoff, self.ignore, self.method)

@batch_transform("channel", has_batch_dim=True, has_depth_dim=False)
def apply_to_images(self, images: np.ndarray, **params: Any) -> np.ndarray:
return self.apply(images, **params)

@batch_transform("channel", has_batch_dim=False, has_depth_dim=True)
def apply_to_volume(self, volume: np.ndarray, **params: Any) -> np.ndarray:
return self.apply(volume, **params)

@batch_transform("channel", has_batch_dim=True, has_depth_dim=True)
def apply_to_volumes(self, volumes: np.ndarray, **params: Any) -> np.ndarray:
return self.apply(volumes, **params)

def get_transform_init_args_names(self) -> tuple[str, ...]:
return ()
return "cutoff", "ignore", "method"
5 changes: 4 additions & 1 deletion tests/aug_definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,10 @@
[A.PlasmaShadow, {}],
[A.Illumination, {}],
[A.ThinPlateSpline, {}],
[A.AutoContrast, {}],
[A.AutoContrast, [
{"cutoff": 0, "ignore": None, "method": "cdf"},
{"cutoff": 0, "ignore": None, "method": "pil"},
]],
[A.PadIfNeeded3D, {"min_zyx": (300, 200, 400), "pad_divisor_zyx": (10, 10, 10), "position": "center", "fill": 10, "fill_mask": 20}],
[A.Pad3D, {"padding": 10}],
[A.CenterCrop3D, {"size": (2, 30, 30)}],
Expand Down
Loading

0 comments on commit fcda65b

Please sign in to comment.