diff --git a/satpy/enhancements/__init__.py b/satpy/enhancements/__init__.py index 4849f9e2f6..bf15e0de4c 100644 --- a/satpy/enhancements/__init__.py +++ b/satpy/enhancements/__init__.py @@ -19,7 +19,8 @@ import logging import os import warnings -from functools import partial +from collections import namedtuple +from functools import wraps from numbers import Number import dask @@ -49,67 +50,77 @@ def invert(img, *args): return img.invert(*args) -def apply_enhancement(data, func, exclude=None, separate=False, - pass_dask=False): - """Apply `func` to the provided data. +def exclude_alpha(func): + """Exclude the alpha channel from the DataArray before further processing.""" + @wraps(func) + def wrapper(data, **kwargs): + bands = data.coords['bands'].values + exclude = ['A'] if 'A' in bands else [] + band_data = data.sel(bands=[b for b in bands + if b not in exclude]) + band_data = func(band_data, **kwargs) + + attrs = data.attrs + attrs.update(band_data.attrs) + # combine the new data with the excluded data + new_data = xr.concat([band_data, data.sel(bands=exclude)], + dim='bands') + data.data = new_data.sel(bands=bands).data + data.attrs = attrs + return data + return wrapper - Args: - data (xarray.DataArray): Data to be modified inplace. - func (callable): Function to be applied to an xarray - exclude (iterable): Bands in the 'bands' dimension to not include - in the calculations. - separate (bool): Apply `func` one band at a time. Default is False. - pass_dask (bool): Pass the underlying dask array instead of the - xarray.DataArray. - """ - attrs = data.attrs - bands = data.coords['bands'].values - if exclude is None: - exclude = ['A'] if 'A' in bands else [] +def on_separate_bands(func): + """Apply `func` one band of the DataArray at a time. + + If this decorator is to be applied along with `on_dask_array`, this decorator has to be applied first, eg:: + + @on_separate_bands + @on_dask_array + def my_enhancement_function(data): + ... - if separate: + + """ + @wraps(func) + def wrapper(data, **kwargs): + attrs = data.attrs data_arrs = [] - for idx, band_name in enumerate(bands): - band_data = data.sel(bands=[band_name]) - if band_name in exclude: - # don't modify alpha - data_arrs.append(band_data) - continue - - if pass_dask: - dims = band_data.dims - coords = band_data.coords - d_arr = func(band_data.data, index=idx) - band_data = xr.DataArray(d_arr, dims=dims, coords=coords) - else: - band_data = func(band_data, index=idx) + for idx, band in enumerate(data.coords['bands'].values): + band_data = func(data.sel(bands=[band]), index=idx, **kwargs) data_arrs.append(band_data) # we assume that the func can add attrs attrs.update(band_data.attrs) - data.data = xr.concat(data_arrs, dim='bands').data data.attrs = attrs return data - band_data = data.sel(bands=[b for b in bands - if b not in exclude]) - if pass_dask: - dims = band_data.dims - coords = band_data.coords - d_arr = func(band_data.data) - band_data = xr.DataArray(d_arr, dims=dims, coords=coords) - else: - band_data = func(band_data) + return wrapper - attrs.update(band_data.attrs) - # combine the new data with the excluded data - new_data = xr.concat([band_data, data.sel(bands=exclude)], - dim='bands') - data.data = new_data.sel(bands=bands).data - data.attrs = attrs - return data +def on_dask_array(func): + """Pass the underlying dask array to *func* instead of the xarray.DataArray.""" + @wraps(func) + def wrapper(data, **kwargs): + dims = data.dims + coords = data.coords + d_arr = func(data.data, **kwargs) + return xr.DataArray(d_arr, dims=dims, coords=coords) + return wrapper + + +def using_map_blocks(func): + """Run the provided function using :func:`dask.array.core.map_blocks`. + + This means dask will call the provided function with a single chunk + as a numpy array. + """ + @wraps(func) + def wrapper(data, **kwargs): + return da.map_blocks(func, data, meta=np.array((), dtype=data.dtype), dtype=data.dtype, chunks=data.chunks, + **kwargs) + return on_dask_array(wrapper) def crefl_scaling(img, **kwargs): @@ -185,15 +196,16 @@ def piecewise_linear_stretch( xp = np.asarray(xp) / reference_scale_factor fp = np.asarray(fp) / reference_scale_factor - def func(band_data, xp, fp, index=None): - # Interpolate band on [0,1] using "lazy" arrays (put calculations off until the end). - band_data = xr.DataArray(da.clip(band_data.data.map_blocks(np.interp, xp=xp, fp=fp), 0, 1), - coords=band_data.coords, dims=band_data.dims, name=band_data.name, - attrs=band_data.attrs) - return band_data + return _piecewise_linear(img.data, xp=xp, fp=fp) - func_with_kwargs = partial(func, xp=xp, fp=fp) - return apply_enhancement(img.data, func_with_kwargs, separate=True) + +@exclude_alpha +@using_map_blocks +def _piecewise_linear(band_data, xp, fp): + # Interpolate band on [0,1] using "lazy" arrays (put calculations off until the end). + interp_data = np.interp(band_data, xp=xp, fp=fp) + interp_data = np.clip(interp_data, 0, 1, out=interp_data) + return interp_data def cira_stretch(img, **kwargs): @@ -202,18 +214,19 @@ def cira_stretch(img, **kwargs): Applicable only for visible channels. """ LOG.debug("Applying the cira-stretch") + return _cira_stretch(img.data) - def func(band_data): - log_root = np.log10(0.0223) - denom = (1.0 - log_root) * 0.75 - band_data *= 0.01 - band_data = band_data.clip(np.finfo(float).eps) - band_data = np.log10(band_data) - band_data -= log_root - band_data /= denom - return band_data - return apply_enhancement(img.data, func) +@exclude_alpha +def _cira_stretch(band_data): + log_root = np.log10(0.0223) + denom = (1.0 - log_root) * 0.75 + band_data *= 0.01 + band_data = band_data.clip(np.finfo(float).eps) + band_data = np.log10(band_data) + band_data -= log_root + band_data /= denom + return band_data def reinhard_to_srgb(img, saturation=1.25, white=100, **kwargs): @@ -272,18 +285,21 @@ def _lookup_delayed(luts, band_data): def lookup(img, **kwargs): """Assign values to channels based on a table.""" luts = np.array(kwargs['luts'], dtype=np.float32) / 255.0 + return _lookup_table(img.data, luts=luts) - def func(band_data, luts=luts, index=-1): - # NaN/null values will become 0 - lut = luts[:, index] if len(luts.shape) == 2 else luts - band_data = band_data.clip(0, lut.size - 1).astype(np.uint8) - new_delay = dask.delayed(_lookup_delayed)(lut, band_data) - new_data = da.from_delayed(new_delay, shape=band_data.shape, - dtype=luts.dtype) - return new_data +@exclude_alpha +@on_separate_bands +@using_map_blocks +def _lookup_table(band_data, luts=None, index=-1): + # NaN/null values will become 0 + lut = luts[:, index] if len(luts.shape) == 2 else luts + band_data = band_data.clip(0, lut.size - 1).astype(np.uint8) - return apply_enhancement(img.data, func, separate=True, pass_dask=True) + new_delay = dask.delayed(_lookup_delayed)(lut, band_data) + new_data = da.from_delayed(new_delay, shape=band_data.shape, + dtype=luts.dtype) + return new_data def colorize(img, **kwargs): @@ -510,14 +526,6 @@ def _read_colormap_data_from_file(filename): return np.loadtxt(filename, delimiter=",") -def _three_d_effect_delayed(band_data, kernel, mode): - """Kernel for running delayed 3D effect creation.""" - from scipy.signal import convolve2d - band_data = band_data.reshape(band_data.shape[1:]) - new_data = convolve2d(band_data, kernel, mode=mode) - return new_data.reshape((1, band_data.shape[0], band_data.shape[1])) - - def three_d_effect(img, **kwargs): """Create 3D effect using convolution.""" w = kwargs.get('weight', 1) @@ -527,14 +535,26 @@ def three_d_effect(img, **kwargs): [-w, 0, w]]) mode = kwargs.get('convolve_mode', 'same') - def func(band_data, kernel=kernel, mode=mode, index=None): - del index + return _three_d_effect(img.data, kernel=kernel, mode=mode) + + +@exclude_alpha +@on_separate_bands +@using_map_blocks +def _three_d_effect(band_data, kernel=None, mode=None, index=None): + del index - delay = dask.delayed(_three_d_effect_delayed)(band_data, kernel, mode) - new_data = da.from_delayed(delay, shape=band_data.shape, dtype=band_data.dtype) - return new_data + delay = dask.delayed(_three_d_effect_delayed)(band_data, kernel, mode) + new_data = da.from_delayed(delay, shape=band_data.shape, dtype=band_data.dtype) + return new_data - return apply_enhancement(img.data, func, separate=True, pass_dask=True) + +def _three_d_effect_delayed(band_data, kernel, mode): + """Kernel for running delayed 3D effect creation.""" + from scipy.signal import convolve2d + band_data = band_data.reshape(band_data.shape[1:]) + new_data = convolve2d(band_data, kernel, mode=mode) + return new_data.reshape((1, band_data.shape[0], band_data.shape[1])) def btemp_threshold(img, min_in, max_in, threshold, threshold_out=None, **kwargs): @@ -563,10 +583,20 @@ def btemp_threshold(img, min_in, max_in, threshold, threshold_out=None, **kwargs high_factor = threshold_out / (max_in - threshold) high_offset = high_factor * max_in - def _bt_threshold(band_data): - # expects dask array to be passed - return da.where(band_data >= threshold, - high_offset - high_factor * band_data, - low_offset - low_factor * band_data) + Coeffs = namedtuple("Coeffs", "factor offset") + high = Coeffs(high_factor, high_offset) + low = Coeffs(low_factor, low_offset) + + return _bt_threshold(img.data, + threshold=threshold, + high_coeffs=high, + low_coeffs=low) + - return apply_enhancement(img.data, _bt_threshold, pass_dask=True) +@exclude_alpha +@using_map_blocks +def _bt_threshold(band_data, threshold, high_coeffs, low_coeffs): + # expects dask array to be passed + return da.where(band_data >= threshold, + high_coeffs.offset - high_coeffs.factor * band_data, + low_coeffs.offset - low_coeffs.factor * band_data) diff --git a/satpy/enhancements/abi.py b/satpy/enhancements/abi.py index da246f51d9..ca19b4b252 100644 --- a/satpy/enhancements/abi.py +++ b/satpy/enhancements/abi.py @@ -16,29 +16,32 @@ # satpy. If not, see . """Enhancement functions specific to the ABI sensor.""" -from satpy.enhancements import apply_enhancement +from satpy.enhancements import exclude_alpha, using_map_blocks def cimss_true_color_contrast(img, **kwargs): """Scale data based on CIMSS True Color recipe for AWIPS.""" - def func(img_data): - """Perform per-chunk enhancement. + _cimss_true_color_contrast(img.data) - Code ported from Kaba Bah's AWIPS python plugin for creating the - CIMSS Natural (True) Color image in AWIPS. AWIPS provides that python - code the image data on a 0-255 scale. Satpy gives this function the - data on a 0-1.0 scale (assuming linear stretching and sqrt - enhancements have already been applied). - """ - max_value = 1.0 - acont = (255.0 / 10.0) / 255.0 - amax = (255.0 + 4.0) / 255.0 - amid = 1.0 / 2.0 - afact = (amax * (acont + max_value) / (max_value * (amax - acont))) - aband = (afact * (img_data - amid) + amid) - aband[aband <= 10 / 255.0] = 0 - aband[aband >= 1.0] = 1.0 - return aband +@exclude_alpha +@using_map_blocks +def _cimss_true_color_contrast(img_data): + """Perform per-chunk enhancement. - apply_enhancement(img.data, func, pass_dask=True) + Code ported from Kaba Bah's AWIPS python plugin for creating the + CIMSS Natural (True) Color image in AWIPS. AWIPS provides that python + code the image data on a 0-255 scale. Satpy gives this function the + data on a 0-1.0 scale (assuming linear stretching and sqrt + enhancements have already been applied). + + """ + max_value = 1.0 + acont = (255.0 / 10.0) / 255.0 + amax = (255.0 + 4.0) / 255.0 + amid = 1.0 / 2.0 + afact = (amax * (acont + max_value) / (max_value * (amax - acont))) + aband = (afact * (img_data - amid) + amid) + aband[aband <= 10 / 255.0] = 0 + aband[aband >= 1.0] = 1.0 + return aband diff --git a/satpy/enhancements/ahi.py b/satpy/enhancements/ahi.py index bafe55f1d6..a0f332cfa2 100644 --- a/satpy/enhancements/ahi.py +++ b/satpy/enhancements/ahi.py @@ -18,7 +18,7 @@ import dask.array as da import numpy as np -from satpy.enhancements import apply_enhancement +from satpy.enhancements import exclude_alpha, on_dask_array def jma_true_color_reproduction(img, **kwargs): @@ -31,14 +31,16 @@ def jma_true_color_reproduction(img, **kwargs): Colorado State University—CIRA https://www.jma.go.jp/jma/jma-eng/satellite/introduction/TCR.html """ + _jma_true_color_reproduction(img.data) - def func(img_data): - ccm = np.array([ - [1.1759, 0.0561, -0.1322], - [-0.0386, 0.9587, 0.0559], - [-0.0189, -0.1161, 1.0777] - ]) - output = da.dot(img_data.T, ccm.T) - return output.T - apply_enhancement(img.data, func, pass_dask=True) +@exclude_alpha +@on_dask_array +def _jma_true_color_reproduction(img_data): + ccm = np.array([ + [1.1759, 0.0561, -0.1322], + [-0.0386, 0.9587, 0.0559], + [-0.0189, -0.1161, 1.0777] + ]) + output = da.dot(img_data.T, ccm.T) + return output.T diff --git a/satpy/enhancements/viirs.py b/satpy/enhancements/viirs.py index 9bd90200b0..6a465f161b 100644 --- a/satpy/enhancements/viirs.py +++ b/satpy/enhancements/viirs.py @@ -18,7 +18,7 @@ import numpy as np from trollimage.colormap import Colormap -from satpy.enhancements import apply_enhancement +from satpy.enhancements import exclude_alpha, using_map_blocks def water_detection(img, **kwargs): @@ -30,14 +30,17 @@ def water_detection(img, **kwargs): palette = kwargs['palettes'] palette['colors'] = tuple(map(tuple, palette['colors'])) - def func(img_data): - data = np.asarray(img_data) - data[data == 150] = 31 - data[data == 199] = 18 - data[data >= 200] = data[data >= 200] - 100 - - return data - - apply_enhancement(img.data, func, pass_dask=True) + _water_detection(img.data) cm = Colormap(*palette['colors']) img.palettize(cm) + + +@exclude_alpha +@using_map_blocks +def _water_detection(img_data): + data = np.asarray(img_data) + data[data == 150] = 31 + data[data == 199] = 18 + data[data >= 200] = data[data >= 200] - 100 + + return data diff --git a/satpy/tests/enhancement_tests/test_enhancements.py b/satpy/tests/enhancement_tests/test_enhancements.py index 656bdab3df..d40c24b411 100644 --- a/satpy/tests/enhancement_tests/test_enhancements.py +++ b/satpy/tests/enhancement_tests/test_enhancements.py @@ -27,7 +27,7 @@ import pytest import xarray as xr -from satpy.enhancements import create_colormap +from satpy.enhancements import create_colormap, on_dask_array, on_separate_bands, using_map_blocks def run_and_check_enhancement(func, data, expected, **kwargs): @@ -47,6 +47,11 @@ def run_and_check_enhancement(func, data, expected, **kwargs): np.testing.assert_allclose(img.data.values, expected, atol=1.e-6, rtol=0) +def identical_decorator(func): + """Decorate but do nothing.""" + return func + + class TestEnhancementStretch: """Class for testing enhancements in satpy.enhancements.""" @@ -64,6 +69,30 @@ def setup_method(self): self.rgb = xr.DataArray(rgb_data, dims=('bands', 'y', 'x'), coords={'bands': ['R', 'G', 'B']}) + @pytest.mark.parametrize( + ("decorator", "exp_call_cls"), + [ + (identical_decorator, xr.DataArray), + (on_dask_array, da.Array), + (using_map_blocks, np.ndarray), + ], + ) + @pytest.mark.parametrize("input_data_name", ["ch1", "ch2", "rgb"]) + def test_apply_enhancement(self, input_data_name, decorator, exp_call_cls): + """Test the 'apply_enhancement' utility function.""" + def _enh_func(img): + def _calc_func(data): + assert isinstance(data, exp_call_cls) + return data + decorated_func = decorator(_calc_func) + return decorated_func(img.data) + + in_data = getattr(self, input_data_name) + exp_data = in_data.values + if "bands" not in in_data.coords: + exp_data = exp_data[np.newaxis, :, :] + run_and_check_enhancement(_enh_func, in_data, exp_data) + def test_cira_stretch(self): """Test applying the cira_stretch.""" from satpy.enhancements import cira_stretch @@ -417,3 +446,41 @@ def test_cmap_list(self): assert cmap.values.shape[0] == 4 assert cmap.values[0] == 2 assert cmap.values[-1] == 8 + + +def test_on_separate_bands(): + """Test the `on_separate_bands` decorator.""" + def func(array, index, gain=2): + return xr.DataArray(np.ones(array.shape, dtype=array.dtype) * index * gain, + coords=array.coords, dims=array.dims, attrs=array.attrs) + + separate_func = on_separate_bands(func) + arr = xr.DataArray(np.zeros((3, 10, 10)), dims=['bands', 'y', 'x'], coords={"bands": ["R", "G", "B"]}) + assert separate_func(arr).shape == arr.shape + assert all(separate_func(arr, gain=1).values[:, 0, 0] == [0, 1, 2]) + + +def test_using_map_blocks(): + """Test the `using_map_blocks` decorator.""" + def func(np_array, block_info=None): + value = block_info[0]['chunk-location'][-1] + return np.ones(np_array.shape) * value + + map_blocked_func = using_map_blocks(func) + arr = xr.DataArray(da.zeros((3, 10, 10), dtype=int, chunks=5), dims=['bands', 'y', 'x']) + res = map_blocked_func(arr) + assert res.shape == arr.shape + assert res[0, 0, 0].compute() != res[0, 9, 9].compute() + + +def test_on_dask_array(): + """Test the `on_dask_array` decorator.""" + def func(dask_array): + if not isinstance(dask_array, da.core.Array): + pytest.fail("Array is not a dask array") + return dask_array + + dask_func = on_dask_array(func) + arr = xr.DataArray(da.zeros((3, 10, 10), dtype=int, chunks=5), dims=['bands', 'y', 'x']) + res = dask_func(arr) + assert res.shape == arr.shape