From 0131c643d8176be55d20f79270da7b310564511f Mon Sep 17 00:00:00 2001 From: David Hoese Date: Sat, 25 Jun 2022 10:04:00 -0500 Subject: [PATCH 1/6] Add 'use_map_blocks' to `apply_enhancement` utility function --- satpy/enhancements/__init__.py | 38 +++++++++++++++++++++++++--------- 1 file changed, 28 insertions(+), 10 deletions(-) diff --git a/satpy/enhancements/__init__.py b/satpy/enhancements/__init__.py index 4849f9e2f6..1bb783044a 100644 --- a/satpy/enhancements/__init__.py +++ b/satpy/enhancements/__init__.py @@ -49,8 +49,14 @@ def invert(img, *args): return img.invert(*args) -def apply_enhancement(data, func, exclude=None, separate=False, - pass_dask=False): +def apply_enhancement( + data, + func, + exclude=None, + separate=False, + pass_dask=False, + use_map_blocks=True, +): """Apply `func` to the provided data. Args: @@ -60,7 +66,13 @@ def apply_enhancement(data, func, exclude=None, separate=False, in the calculations. separate (bool): Apply `func` one band at a time. Default is False. pass_dask (bool): Pass the underlying dask array instead of the - xarray.DataArray. + xarray.DataArray. + use_map_blocks (bool): If this option and ``pass_dask`` are ``True``, + run the provided function using :func:`dask.array.core.map_blocks`. + This means dask will call the provided function with a single chunk + as a numpy array. If ``False`` the function is passed the + underlying dask array directly. If ``pass_dask`` is ``False`` this + has no effect. """ attrs = data.attrs @@ -97,7 +109,14 @@ def apply_enhancement(data, func, exclude=None, separate=False, if pass_dask: dims = band_data.dims coords = band_data.coords - d_arr = func(band_data.data) + if use_map_blocks: + d_arr = da.map_blocks(func, + band_data.data, + meta=np.array((), dtype=band_data.dtype), + dtype=band_data.dtype, + chunks=band_data.chunks) + else: + d_arr = func(band_data.data) band_data = xr.DataArray(d_arr, dims=dims, coords=coords) else: band_data = func(band_data) @@ -185,15 +204,14 @@ def piecewise_linear_stretch( xp = np.asarray(xp) / reference_scale_factor fp = np.asarray(fp) / reference_scale_factor - def func(band_data, xp, fp, index=None): + def func(band_data, xp, fp): # Interpolate band on [0,1] using "lazy" arrays (put calculations off until the end). - band_data = xr.DataArray(da.clip(band_data.data.map_blocks(np.interp, xp=xp, fp=fp), 0, 1), - coords=band_data.coords, dims=band_data.dims, name=band_data.name, - attrs=band_data.attrs) - return band_data + interp_data = np.interp(band_data, xp=xp, fp=fp) + interp_data = np.clip(interp_data, 0, 1, out=interp_data) + return interp_data func_with_kwargs = partial(func, xp=xp, fp=fp) - return apply_enhancement(img.data, func_with_kwargs, separate=True) + return apply_enhancement(img.data, func_with_kwargs, separate=False, pass_dask=True) def cira_stretch(img, **kwargs): From dc28b1d50bdd4adb91d782dbc3cf3c474638e0f5 Mon Sep 17 00:00:00 2001 From: David Hoese Date: Sat, 25 Jun 2022 15:03:35 -0500 Subject: [PATCH 2/6] Add tests for apply_enhancement and start refactoring --- satpy/enhancements/__init__.py | 72 +++++++++++++------ .../enhancement_tests/test_enhancements.py | 27 +++++++ 2 files changed, 78 insertions(+), 21 deletions(-) diff --git a/satpy/enhancements/__init__.py b/satpy/enhancements/__init__.py index 1bb783044a..cee459be21 100644 --- a/satpy/enhancements/__init__.py +++ b/satpy/enhancements/__init__.py @@ -75,35 +75,65 @@ def apply_enhancement( has no effect. """ - attrs = data.attrs bands = data.coords['bands'].values if exclude is None: exclude = ['A'] if 'A' in bands else [] if separate: - data_arrs = [] - for idx, band_name in enumerate(bands): - band_data = data.sel(bands=[band_name]) - if band_name in exclude: - # don't modify alpha - data_arrs.append(band_data) - continue - - if pass_dask: - dims = band_data.dims - coords = band_data.coords - d_arr = func(band_data.data, index=idx) - band_data = xr.DataArray(d_arr, dims=dims, coords=coords) - else: - band_data = func(band_data, index=idx) + return _enhance_separate_bands( + func, + data, + bands, + exclude, + pass_dask, + use_map_blocks, + ) + + return _enhance_whole_array( + func, + data, + bands, + exclude, + pass_dask, + use_map_blocks, + ) + + +def _enhance_separate_bands(func, data, bands, exclude, pass_dask, use_map_blocks): + attrs = data.attrs + data_arrs = [] + for idx, band_name in enumerate(bands): + band_data = data.sel(bands=[band_name]) + if band_name in exclude: + # don't modify alpha data_arrs.append(band_data) - # we assume that the func can add attrs - attrs.update(band_data.attrs) + continue + + if pass_dask: + dims = band_data.dims + coords = band_data.coords + if use_map_blocks: + d_arr = da.map_blocks(func, + band_data.data, + meta=np.array((), dtype=band_data.dtype), + dtype=band_data.dtype, + chunks=band_data.chunks) + else: + d_arr = func(band_data.data, index=idx) + band_data = xr.DataArray(d_arr, dims=dims, coords=coords) + else: + band_data = func(band_data, index=idx) + data_arrs.append(band_data) + # we assume that the func can add attrs + attrs.update(band_data.attrs) + + data.data = xr.concat(data_arrs, dim='bands').data + data.attrs = attrs + return data - data.data = xr.concat(data_arrs, dim='bands').data - data.attrs = attrs - return data +def _enhance_whole_array(func, data, bands, exclude, pass_dask, use_map_blocks): + attrs = data.attrs band_data = data.sel(bands=[b for b in bands if b not in exclude]) if pass_dask: diff --git a/satpy/tests/enhancement_tests/test_enhancements.py b/satpy/tests/enhancement_tests/test_enhancements.py index 656bdab3df..6ec09a734f 100644 --- a/satpy/tests/enhancement_tests/test_enhancements.py +++ b/satpy/tests/enhancement_tests/test_enhancements.py @@ -64,6 +64,33 @@ def setup_method(self): self.rgb = xr.DataArray(rgb_data, dims=('bands', 'y', 'x'), coords={'bands': ['R', 'G', 'B']}) + @pytest.mark.parametrize( + ("pass_dask", "use_map_blocks", "exp_call_cls"), + [ + (False, False, xr.DataArray), + (False, True, xr.DataArray), # no map_blocks + (True, False, da.Array), + (True, True, np.ndarray), + ], + ) + @pytest.mark.parametrize("input_data_name", ["ch1", "ch2", "rgb"]) + def test_apply_enhancement(self, input_data_name, pass_dask, use_map_blocks, exp_call_cls): + """Test the 'apply_enhancement' utility function.""" + from satpy.enhancements import apply_enhancement + + def _enh_func(img): + def _calc_func(data): + assert isinstance(data, exp_call_cls) + return data + + return apply_enhancement(img.data, _calc_func, pass_dask=pass_dask, use_map_blocks=use_map_blocks) + + in_data = getattr(self, input_data_name) + exp_data = in_data.values + if "bands" not in in_data.coords: + exp_data = exp_data[np.newaxis, :, :] + run_and_check_enhancement(_enh_func, in_data, exp_data) + def test_cira_stretch(self): """Test applying the cira_stretch.""" from satpy.enhancements import cira_stretch From 446e39fa76a6ba777d57ee9d8d472f9c3ee676e9 Mon Sep 17 00:00:00 2001 From: David Hoese Date: Sat, 25 Jun 2022 15:51:59 -0500 Subject: [PATCH 3/6] Fix JMA reproduction enhancement --- satpy/enhancements/ahi.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/satpy/enhancements/ahi.py b/satpy/enhancements/ahi.py index bafe55f1d6..c852d38e40 100644 --- a/satpy/enhancements/ahi.py +++ b/satpy/enhancements/ahi.py @@ -41,4 +41,4 @@ def func(img_data): output = da.dot(img_data.T, ccm.T) return output.T - apply_enhancement(img.data, func, pass_dask=True) + apply_enhancement(img.data, func, pass_dask=True, use_map_blocks=False) From 9322f0c38b35ef2343da774f0d972c3b770a4315 Mon Sep 17 00:00:00 2001 From: David Hoese Date: Sat, 25 Jun 2022 20:20:09 -0500 Subject: [PATCH 4/6] Refactor apply_enhancement to remove some duplicate code --- satpy/enhancements/__init__.py | 50 ++++++++++++++-------------------- 1 file changed, 21 insertions(+), 29 deletions(-) diff --git a/satpy/enhancements/__init__.py b/satpy/enhancements/__init__.py index cee459be21..e478280ec1 100644 --- a/satpy/enhancements/__init__.py +++ b/satpy/enhancements/__init__.py @@ -109,20 +109,7 @@ def _enhance_separate_bands(func, data, bands, exclude, pass_dask, use_map_block data_arrs.append(band_data) continue - if pass_dask: - dims = band_data.dims - coords = band_data.coords - if use_map_blocks: - d_arr = da.map_blocks(func, - band_data.data, - meta=np.array((), dtype=band_data.dtype), - dtype=band_data.dtype, - chunks=band_data.chunks) - else: - d_arr = func(band_data.data, index=idx) - band_data = xr.DataArray(d_arr, dims=dims, coords=coords) - else: - band_data = func(band_data, index=idx) + band_data = _call_enh_func(func, band_data, pass_dask, use_map_blocks, {"index": idx}) data_arrs.append(band_data) # we assume that the func can add attrs attrs.update(band_data.attrs) @@ -136,20 +123,7 @@ def _enhance_whole_array(func, data, bands, exclude, pass_dask, use_map_blocks): attrs = data.attrs band_data = data.sel(bands=[b for b in bands if b not in exclude]) - if pass_dask: - dims = band_data.dims - coords = band_data.coords - if use_map_blocks: - d_arr = da.map_blocks(func, - band_data.data, - meta=np.array((), dtype=band_data.dtype), - dtype=band_data.dtype, - chunks=band_data.chunks) - else: - d_arr = func(band_data.data) - band_data = xr.DataArray(d_arr, dims=dims, coords=coords) - else: - band_data = func(band_data) + band_data = _call_enh_func(func, band_data, pass_dask, use_map_blocks, {}) attrs.update(band_data.attrs) # combine the new data with the excluded data @@ -157,10 +131,28 @@ def _enhance_whole_array(func, data, bands, exclude, pass_dask, use_map_blocks): dim='bands') data.data = new_data.sel(bands=bands).data data.attrs = attrs - return data +def _call_enh_func(func, band_data_arr, pass_dask, use_map_blocks, extra_kwargs): + if pass_dask: + dims = band_data_arr.dims + coords = band_data_arr.coords + if use_map_blocks: + d_arr = da.map_blocks(func, + band_data_arr.data, + meta=np.array((), dtype=band_data_arr.dtype), + dtype=band_data_arr.dtype, + chunks=band_data_arr.chunks, + **extra_kwargs) + else: + d_arr = func(band_data_arr.data, **extra_kwargs) + band_data_arr = xr.DataArray(d_arr, dims=dims, coords=coords) + else: + band_data_arr = func(band_data_arr, **extra_kwargs) + return band_data_arr + + def crefl_scaling(img, **kwargs): """Apply non-linear stretch used by CREFL-based RGBs.""" LOG.debug("Applying the crefl_scaling") From e7f14f7018955345ec29d3d29eebd889e639f02b Mon Sep 17 00:00:00 2001 From: Martin Raspaud Date: Fri, 5 Aug 2022 15:53:34 +0200 Subject: [PATCH 5/6] Refactor enhancement application --- satpy/enhancements/__init__.py | 205 +++++++++--------- satpy/enhancements/abi.py | 40 ++-- satpy/enhancements/ahi.py | 21 +- satpy/enhancements/viirs.py | 22 +- .../enhancement_tests/test_enhancements.py | 60 ++++- 5 files changed, 197 insertions(+), 151 deletions(-) diff --git a/satpy/enhancements/__init__.py b/satpy/enhancements/__init__.py index e478280ec1..57eb02f15e 100644 --- a/satpy/enhancements/__init__.py +++ b/satpy/enhancements/__init__.py @@ -19,7 +19,7 @@ import logging import os import warnings -from functools import partial +from functools import partial, wraps from numbers import Number import dask @@ -52,10 +52,7 @@ def invert(img, *args): def apply_enhancement( data, func, - exclude=None, - separate=False, - pass_dask=False, - use_map_blocks=True, + exclude=None ): """Apply `func` to the provided data. @@ -63,67 +60,28 @@ def apply_enhancement( data (xarray.DataArray): Data to be modified inplace. func (callable): Function to be applied to an xarray exclude (iterable): Bands in the 'bands' dimension to not include - in the calculations. - separate (bool): Apply `func` one band at a time. Default is False. - pass_dask (bool): Pass the underlying dask array instead of the - xarray.DataArray. - use_map_blocks (bool): If this option and ``pass_dask`` are ``True``, - run the provided function using :func:`dask.array.core.map_blocks`. - This means dask will call the provided function with a single chunk - as a numpy array. If ``False`` the function is passed the - underlying dask array directly. If ``pass_dask`` is ``False`` this - has no effect. + in the calculations. If not provided or None, the + alpha band if present will be excluded. To include + all channels, pass []. """ bands = data.coords['bands'].values if exclude is None: exclude = ['A'] if 'A' in bands else [] - if separate: - return _enhance_separate_bands( - func, - data, - bands, - exclude, - pass_dask, - use_map_blocks, - ) - return _enhance_whole_array( func, data, bands, - exclude, - pass_dask, - use_map_blocks, + exclude ) -def _enhance_separate_bands(func, data, bands, exclude, pass_dask, use_map_blocks): - attrs = data.attrs - data_arrs = [] - for idx, band_name in enumerate(bands): - band_data = data.sel(bands=[band_name]) - if band_name in exclude: - # don't modify alpha - data_arrs.append(band_data) - continue - - band_data = _call_enh_func(func, band_data, pass_dask, use_map_blocks, {"index": idx}) - data_arrs.append(band_data) - # we assume that the func can add attrs - attrs.update(band_data.attrs) - - data.data = xr.concat(data_arrs, dim='bands').data - data.attrs = attrs - return data - - -def _enhance_whole_array(func, data, bands, exclude, pass_dask, use_map_blocks): +def _enhance_whole_array(func, data, bands, exclude): attrs = data.attrs band_data = data.sel(bands=[b for b in bands if b not in exclude]) - band_data = _call_enh_func(func, band_data, pass_dask, use_map_blocks, {}) + band_data = func(band_data) attrs.update(band_data.attrs) # combine the new data with the excluded data @@ -134,23 +92,46 @@ def _enhance_whole_array(func, data, bands, exclude, pass_dask, use_map_blocks): return data -def _call_enh_func(func, band_data_arr, pass_dask, use_map_blocks, extra_kwargs): - if pass_dask: - dims = band_data_arr.dims - coords = band_data_arr.coords - if use_map_blocks: - d_arr = da.map_blocks(func, - band_data_arr.data, - meta=np.array((), dtype=band_data_arr.dtype), - dtype=band_data_arr.dtype, - chunks=band_data_arr.chunks, - **extra_kwargs) - else: - d_arr = func(band_data_arr.data, **extra_kwargs) - band_data_arr = xr.DataArray(d_arr, dims=dims, coords=coords) - else: - band_data_arr = func(band_data_arr, **extra_kwargs) - return band_data_arr +def on_separate_bands(func): + """Apply `func` one band at a time.""" + @wraps(func) + def wrapper(data, **kwargs): + attrs = data.attrs + data_arrs = [] + for idx, band in enumerate(data.coords['bands'].values): + band_data = func(data.sel(bands=[band]), index=idx, **kwargs) + data_arrs.append(band_data) + # we assume that the func can add attrs + attrs.update(band_data.attrs) + data.data = xr.concat(data_arrs, dim='bands').data + data.attrs = attrs + return data + + return wrapper + + +def on_dask_array(func): + """Pass the underlying dask array to *func* instead of the xarray.DataArray.""" + @wraps(func) + def wrapper(data, **kwargs): + dims = data.dims + coords = data.coords + d_arr = func(data.data, **kwargs) + return xr.DataArray(d_arr, dims=dims, coords=coords) + return wrapper + + +def using_map_blocks(func): + """Run the provided function using :func:`dask.array.core.map_blocks`. + + This means dask will call the provided function with a single chunk + as a numpy array. + """ + @wraps(func) + def wrapper(data, **kwargs): + return da.map_blocks(func, data, meta=np.array((), dtype=data.dtype), dtype=data.dtype, chunks=data.chunks, + **kwargs) + return on_dask_array(wrapper) def crefl_scaling(img, **kwargs): @@ -226,14 +207,16 @@ def piecewise_linear_stretch( xp = np.asarray(xp) / reference_scale_factor fp = np.asarray(fp) / reference_scale_factor - def func(band_data, xp, fp): - # Interpolate band on [0,1] using "lazy" arrays (put calculations off until the end). - interp_data = np.interp(band_data, xp=xp, fp=fp) - interp_data = np.clip(interp_data, 0, 1, out=interp_data) - return interp_data + func_with_kwargs = partial(_piecewise_linear, xp=xp, fp=fp) + return apply_enhancement(img.data, func_with_kwargs) - func_with_kwargs = partial(func, xp=xp, fp=fp) - return apply_enhancement(img.data, func_with_kwargs, separate=False, pass_dask=True) + +@using_map_blocks +def _piecewise_linear(band_data, xp, fp): + # Interpolate band on [0,1] using "lazy" arrays (put calculations off until the end). + interp_data = np.interp(band_data, xp=xp, fp=fp) + interp_data = np.clip(interp_data, 0, 1, out=interp_data) + return interp_data def cira_stretch(img, **kwargs): @@ -313,17 +296,22 @@ def lookup(img, **kwargs): """Assign values to channels based on a table.""" luts = np.array(kwargs['luts'], dtype=np.float32) / 255.0 - def func(band_data, luts=luts, index=-1): - # NaN/null values will become 0 - lut = luts[:, index] if len(luts.shape) == 2 else luts - band_data = band_data.clip(0, lut.size - 1).astype(np.uint8) + partial_lookup_table = partial(_lookup_table, luts=luts) + + return apply_enhancement(img.data, partial_lookup_table) - new_delay = dask.delayed(_lookup_delayed)(lut, band_data) - new_data = da.from_delayed(new_delay, shape=band_data.shape, - dtype=luts.dtype) - return new_data - return apply_enhancement(img.data, func, separate=True, pass_dask=True) +@on_separate_bands +@using_map_blocks +def _lookup_table(band_data, luts=None, index=-1): + # NaN/null values will become 0 + lut = luts[:, index] if len(luts.shape) == 2 else luts + band_data = band_data.clip(0, lut.size - 1).astype(np.uint8) + + new_delay = dask.delayed(_lookup_delayed)(lut, band_data) + new_data = da.from_delayed(new_delay, shape=band_data.shape, + dtype=luts.dtype) + return new_data def colorize(img, **kwargs): @@ -550,14 +538,6 @@ def _read_colormap_data_from_file(filename): return np.loadtxt(filename, delimiter=",") -def _three_d_effect_delayed(band_data, kernel, mode): - """Kernel for running delayed 3D effect creation.""" - from scipy.signal import convolve2d - band_data = band_data.reshape(band_data.shape[1:]) - new_data = convolve2d(band_data, kernel, mode=mode) - return new_data.reshape((1, band_data.shape[0], band_data.shape[1])) - - def three_d_effect(img, **kwargs): """Create 3D effect using convolution.""" w = kwargs.get('weight', 1) @@ -567,14 +547,27 @@ def three_d_effect(img, **kwargs): [-w, 0, w]]) mode = kwargs.get('convolve_mode', 'same') - def func(band_data, kernel=kernel, mode=mode, index=None): - del index + partial_three_d_effect = partial(_three_d_effect, kernel=kernel, mode=mode) + + return apply_enhancement(img.data, partial_three_d_effect) + - delay = dask.delayed(_three_d_effect_delayed)(band_data, kernel, mode) - new_data = da.from_delayed(delay, shape=band_data.shape, dtype=band_data.dtype) - return new_data +@on_separate_bands +@using_map_blocks +def _three_d_effect(band_data, kernel=None, mode=None, index=None): + del index - return apply_enhancement(img.data, func, separate=True, pass_dask=True) + delay = dask.delayed(_three_d_effect_delayed)(band_data, kernel, mode) + new_data = da.from_delayed(delay, shape=band_data.shape, dtype=band_data.dtype) + return new_data + + +def _three_d_effect_delayed(band_data, kernel, mode): + """Kernel for running delayed 3D effect creation.""" + from scipy.signal import convolve2d + band_data = band_data.reshape(band_data.shape[1:]) + new_data = convolve2d(band_data, kernel, mode=mode) + return new_data.reshape((1, band_data.shape[0], band_data.shape[1])) def btemp_threshold(img, min_in, max_in, threshold, threshold_out=None, **kwargs): @@ -603,10 +596,16 @@ def btemp_threshold(img, min_in, max_in, threshold, threshold_out=None, **kwargs high_factor = threshold_out / (max_in - threshold) high_offset = high_factor * max_in - def _bt_threshold(band_data): - # expects dask array to be passed - return da.where(band_data >= threshold, - high_offset - high_factor * band_data, - low_offset - low_factor * band_data) + partial_bt_threshold = partial(_bt_threshold, threshold=threshold, + high_offset=high_offset, high_factor=high_factor, + low_offset=low_offset, low_factor=low_factor) + + return apply_enhancement(img.data, partial_bt_threshold) + - return apply_enhancement(img.data, _bt_threshold, pass_dask=True) +@using_map_blocks +def _bt_threshold(band_data, threshold, high_offset, high_factor, low_offset, low_factor): + # expects dask array to be passed + return da.where(band_data >= threshold, + high_offset - high_factor * band_data, + low_offset - low_factor * band_data) diff --git a/satpy/enhancements/abi.py b/satpy/enhancements/abi.py index da246f51d9..871c097042 100644 --- a/satpy/enhancements/abi.py +++ b/satpy/enhancements/abi.py @@ -16,29 +16,31 @@ # satpy. If not, see . """Enhancement functions specific to the ABI sensor.""" -from satpy.enhancements import apply_enhancement +from satpy.enhancements import apply_enhancement, using_map_blocks def cimss_true_color_contrast(img, **kwargs): """Scale data based on CIMSS True Color recipe for AWIPS.""" - def func(img_data): - """Perform per-chunk enhancement. + apply_enhancement(img.data, _cimss_true_color_contrast) - Code ported from Kaba Bah's AWIPS python plugin for creating the - CIMSS Natural (True) Color image in AWIPS. AWIPS provides that python - code the image data on a 0-255 scale. Satpy gives this function the - data on a 0-1.0 scale (assuming linear stretching and sqrt - enhancements have already been applied). - """ - max_value = 1.0 - acont = (255.0 / 10.0) / 255.0 - amax = (255.0 + 4.0) / 255.0 - amid = 1.0 / 2.0 - afact = (amax * (acont + max_value) / (max_value * (amax - acont))) - aband = (afact * (img_data - amid) + amid) - aband[aband <= 10 / 255.0] = 0 - aband[aband >= 1.0] = 1.0 - return aband +@using_map_blocks +def _cimss_true_color_contrast(img_data): + """Perform per-chunk enhancement. - apply_enhancement(img.data, func, pass_dask=True) + Code ported from Kaba Bah's AWIPS python plugin for creating the + CIMSS Natural (True) Color image in AWIPS. AWIPS provides that python + code the image data on a 0-255 scale. Satpy gives this function the + data on a 0-1.0 scale (assuming linear stretching and sqrt + enhancements have already been applied). + + """ + max_value = 1.0 + acont = (255.0 / 10.0) / 255.0 + amax = (255.0 + 4.0) / 255.0 + amid = 1.0 / 2.0 + afact = (amax * (acont + max_value) / (max_value * (amax - acont))) + aband = (afact * (img_data - amid) + amid) + aband[aband <= 10 / 255.0] = 0 + aband[aband >= 1.0] = 1.0 + return aband diff --git a/satpy/enhancements/ahi.py b/satpy/enhancements/ahi.py index c852d38e40..a6896b63c5 100644 --- a/satpy/enhancements/ahi.py +++ b/satpy/enhancements/ahi.py @@ -18,7 +18,7 @@ import dask.array as da import numpy as np -from satpy.enhancements import apply_enhancement +from satpy.enhancements import apply_enhancement, on_dask_array def jma_true_color_reproduction(img, **kwargs): @@ -31,14 +31,15 @@ def jma_true_color_reproduction(img, **kwargs): Colorado State University—CIRA https://www.jma.go.jp/jma/jma-eng/satellite/introduction/TCR.html """ + apply_enhancement(img.data, _jma_true_color_reproduction) - def func(img_data): - ccm = np.array([ - [1.1759, 0.0561, -0.1322], - [-0.0386, 0.9587, 0.0559], - [-0.0189, -0.1161, 1.0777] - ]) - output = da.dot(img_data.T, ccm.T) - return output.T - apply_enhancement(img.data, func, pass_dask=True, use_map_blocks=False) +@on_dask_array +def _jma_true_color_reproduction(img_data): + ccm = np.array([ + [1.1759, 0.0561, -0.1322], + [-0.0386, 0.9587, 0.0559], + [-0.0189, -0.1161, 1.0777] + ]) + output = da.dot(img_data.T, ccm.T) + return output.T diff --git a/satpy/enhancements/viirs.py b/satpy/enhancements/viirs.py index 9bd90200b0..e988fb53fd 100644 --- a/satpy/enhancements/viirs.py +++ b/satpy/enhancements/viirs.py @@ -18,7 +18,7 @@ import numpy as np from trollimage.colormap import Colormap -from satpy.enhancements import apply_enhancement +from satpy.enhancements import apply_enhancement, using_map_blocks def water_detection(img, **kwargs): @@ -30,14 +30,16 @@ def water_detection(img, **kwargs): palette = kwargs['palettes'] palette['colors'] = tuple(map(tuple, palette['colors'])) - def func(img_data): - data = np.asarray(img_data) - data[data == 150] = 31 - data[data == 199] = 18 - data[data >= 200] = data[data >= 200] - 100 - - return data - - apply_enhancement(img.data, func, pass_dask=True) + apply_enhancement(img.data, _water_detection) cm = Colormap(*palette['colors']) img.palettize(cm) + + +@using_map_blocks +def _water_detection(img_data): + data = np.asarray(img_data) + data[data == 150] = 31 + data[data == 199] = 18 + data[data >= 200] = data[data >= 200] - 100 + + return data diff --git a/satpy/tests/enhancement_tests/test_enhancements.py b/satpy/tests/enhancement_tests/test_enhancements.py index 6ec09a734f..66dc8100d0 100644 --- a/satpy/tests/enhancement_tests/test_enhancements.py +++ b/satpy/tests/enhancement_tests/test_enhancements.py @@ -27,7 +27,7 @@ import pytest import xarray as xr -from satpy.enhancements import create_colormap +from satpy.enhancements import create_colormap, on_dask_array, on_separate_bands, using_map_blocks def run_and_check_enhancement(func, data, expected, **kwargs): @@ -47,6 +47,11 @@ def run_and_check_enhancement(func, data, expected, **kwargs): np.testing.assert_allclose(img.data.values, expected, atol=1.e-6, rtol=0) +def identical_decorator(func): + """Decorate but do nothing.""" + return func + + class TestEnhancementStretch: """Class for testing enhancements in satpy.enhancements.""" @@ -65,16 +70,15 @@ def setup_method(self): coords={'bands': ['R', 'G', 'B']}) @pytest.mark.parametrize( - ("pass_dask", "use_map_blocks", "exp_call_cls"), + ("decorator", "exp_call_cls"), [ - (False, False, xr.DataArray), - (False, True, xr.DataArray), # no map_blocks - (True, False, da.Array), - (True, True, np.ndarray), + (identical_decorator, xr.DataArray), + (on_dask_array, da.Array), + (using_map_blocks, np.ndarray), ], ) @pytest.mark.parametrize("input_data_name", ["ch1", "ch2", "rgb"]) - def test_apply_enhancement(self, input_data_name, pass_dask, use_map_blocks, exp_call_cls): + def test_apply_enhancement(self, input_data_name, decorator, exp_call_cls): """Test the 'apply_enhancement' utility function.""" from satpy.enhancements import apply_enhancement @@ -82,8 +86,8 @@ def _enh_func(img): def _calc_func(data): assert isinstance(data, exp_call_cls) return data - - return apply_enhancement(img.data, _calc_func, pass_dask=pass_dask, use_map_blocks=use_map_blocks) + decorated_func = decorator(_calc_func) + return apply_enhancement(img.data, decorated_func) in_data = getattr(self, input_data_name) exp_data = in_data.values @@ -444,3 +448,41 @@ def test_cmap_list(self): assert cmap.values.shape[0] == 4 assert cmap.values[0] == 2 assert cmap.values[-1] == 8 + + +def test_on_separate_bands(): + """Test the `on_separate_bands` decorator.""" + def func(array, index, gain=2): + return xr.DataArray(np.ones(array.shape, dtype=array.dtype) * index * gain, + coords=array.coords, dims=array.dims, attrs=array.attrs) + + separate_func = on_separate_bands(func) + arr = xr.DataArray(np.zeros((3, 10, 10)), dims=['bands', 'y', 'x'], coords={"bands": ["R", "G", "B"]}) + assert separate_func(arr).shape == arr.shape + assert all(separate_func(arr, gain=1).values[:, 0, 0] == [0, 1, 2]) + + +def test_using_map_blocks(): + """Test the `using_map_blocks` decorator.""" + def func(np_array, block_info=None): + value = block_info[0]['chunk-location'][-1] + return np.ones(np_array.shape) * value + + map_blocked_func = using_map_blocks(func) + arr = xr.DataArray(da.zeros((3, 10, 10), dtype=int, chunks=5), dims=['bands', 'y', 'x']) + res = map_blocked_func(arr) + assert res.shape == arr.shape + assert res[0, 0, 0].compute() != res[0, 9, 9].compute() + + +def test_on_dask_array(): + """Test the `on_dask_array` decorator.""" + def func(dask_array): + if not isinstance(dask_array, da.core.Array): + pytest.fail("Array is not a dask array") + return dask_array + + dask_func = on_dask_array(func) + arr = xr.DataArray(da.zeros((3, 10, 10), dtype=int, chunks=5), dims=['bands', 'y', 'x']) + res = dask_func(arr) + assert res.shape == arr.shape From d82d36d74bfb5df1036bf7d4314603d5eec5916a Mon Sep 17 00:00:00 2001 From: Martin Raspaud Date: Fri, 5 Aug 2022 17:02:21 +0200 Subject: [PATCH 6/6] Refactor out apply_enhancement --- satpy/enhancements/__init__.py | 121 ++++++++---------- satpy/enhancements/abi.py | 5 +- satpy/enhancements/ahi.py | 5 +- satpy/enhancements/viirs.py | 5 +- .../enhancement_tests/test_enhancements.py | 4 +- 5 files changed, 66 insertions(+), 74 deletions(-) diff --git a/satpy/enhancements/__init__.py b/satpy/enhancements/__init__.py index 57eb02f15e..bf15e0de4c 100644 --- a/satpy/enhancements/__init__.py +++ b/satpy/enhancements/__init__.py @@ -19,7 +19,8 @@ import logging import os import warnings -from functools import partial, wraps +from collections import namedtuple +from functools import wraps from numbers import Number import dask @@ -49,51 +50,39 @@ def invert(img, *args): return img.invert(*args) -def apply_enhancement( - data, - func, - exclude=None -): - """Apply `func` to the provided data. - - Args: - data (xarray.DataArray): Data to be modified inplace. - func (callable): Function to be applied to an xarray - exclude (iterable): Bands in the 'bands' dimension to not include - in the calculations. If not provided or None, the - alpha band if present will be excluded. To include - all channels, pass []. - - """ - bands = data.coords['bands'].values - if exclude is None: +def exclude_alpha(func): + """Exclude the alpha channel from the DataArray before further processing.""" + @wraps(func) + def wrapper(data, **kwargs): + bands = data.coords['bands'].values exclude = ['A'] if 'A' in bands else [] + band_data = data.sel(bands=[b for b in bands + if b not in exclude]) + band_data = func(band_data, **kwargs) + + attrs = data.attrs + attrs.update(band_data.attrs) + # combine the new data with the excluded data + new_data = xr.concat([band_data, data.sel(bands=exclude)], + dim='bands') + data.data = new_data.sel(bands=bands).data + data.attrs = attrs + return data + return wrapper - return _enhance_whole_array( - func, - data, - bands, - exclude - ) +def on_separate_bands(func): + """Apply `func` one band of the DataArray at a time. -def _enhance_whole_array(func, data, bands, exclude): - attrs = data.attrs - band_data = data.sel(bands=[b for b in bands - if b not in exclude]) - band_data = func(band_data) + If this decorator is to be applied along with `on_dask_array`, this decorator has to be applied first, eg:: - attrs.update(band_data.attrs) - # combine the new data with the excluded data - new_data = xr.concat([band_data, data.sel(bands=exclude)], - dim='bands') - data.data = new_data.sel(bands=bands).data - data.attrs = attrs - return data + @on_separate_bands + @on_dask_array + def my_enhancement_function(data): + ... -def on_separate_bands(func): - """Apply `func` one band at a time.""" + """ @wraps(func) def wrapper(data, **kwargs): attrs = data.attrs @@ -207,10 +196,10 @@ def piecewise_linear_stretch( xp = np.asarray(xp) / reference_scale_factor fp = np.asarray(fp) / reference_scale_factor - func_with_kwargs = partial(_piecewise_linear, xp=xp, fp=fp) - return apply_enhancement(img.data, func_with_kwargs) + return _piecewise_linear(img.data, xp=xp, fp=fp) +@exclude_alpha @using_map_blocks def _piecewise_linear(band_data, xp, fp): # Interpolate band on [0,1] using "lazy" arrays (put calculations off until the end). @@ -225,18 +214,19 @@ def cira_stretch(img, **kwargs): Applicable only for visible channels. """ LOG.debug("Applying the cira-stretch") + return _cira_stretch(img.data) - def func(band_data): - log_root = np.log10(0.0223) - denom = (1.0 - log_root) * 0.75 - band_data *= 0.01 - band_data = band_data.clip(np.finfo(float).eps) - band_data = np.log10(band_data) - band_data -= log_root - band_data /= denom - return band_data - return apply_enhancement(img.data, func) +@exclude_alpha +def _cira_stretch(band_data): + log_root = np.log10(0.0223) + denom = (1.0 - log_root) * 0.75 + band_data *= 0.01 + band_data = band_data.clip(np.finfo(float).eps) + band_data = np.log10(band_data) + band_data -= log_root + band_data /= denom + return band_data def reinhard_to_srgb(img, saturation=1.25, white=100, **kwargs): @@ -295,12 +285,10 @@ def _lookup_delayed(luts, band_data): def lookup(img, **kwargs): """Assign values to channels based on a table.""" luts = np.array(kwargs['luts'], dtype=np.float32) / 255.0 - - partial_lookup_table = partial(_lookup_table, luts=luts) - - return apply_enhancement(img.data, partial_lookup_table) + return _lookup_table(img.data, luts=luts) +@exclude_alpha @on_separate_bands @using_map_blocks def _lookup_table(band_data, luts=None, index=-1): @@ -547,11 +535,10 @@ def three_d_effect(img, **kwargs): [-w, 0, w]]) mode = kwargs.get('convolve_mode', 'same') - partial_three_d_effect = partial(_three_d_effect, kernel=kernel, mode=mode) - - return apply_enhancement(img.data, partial_three_d_effect) + return _three_d_effect(img.data, kernel=kernel, mode=mode) +@exclude_alpha @on_separate_bands @using_map_blocks def _three_d_effect(band_data, kernel=None, mode=None, index=None): @@ -596,16 +583,20 @@ def btemp_threshold(img, min_in, max_in, threshold, threshold_out=None, **kwargs high_factor = threshold_out / (max_in - threshold) high_offset = high_factor * max_in - partial_bt_threshold = partial(_bt_threshold, threshold=threshold, - high_offset=high_offset, high_factor=high_factor, - low_offset=low_offset, low_factor=low_factor) + Coeffs = namedtuple("Coeffs", "factor offset") + high = Coeffs(high_factor, high_offset) + low = Coeffs(low_factor, low_offset) - return apply_enhancement(img.data, partial_bt_threshold) + return _bt_threshold(img.data, + threshold=threshold, + high_coeffs=high, + low_coeffs=low) +@exclude_alpha @using_map_blocks -def _bt_threshold(band_data, threshold, high_offset, high_factor, low_offset, low_factor): +def _bt_threshold(band_data, threshold, high_coeffs, low_coeffs): # expects dask array to be passed return da.where(band_data >= threshold, - high_offset - high_factor * band_data, - low_offset - low_factor * band_data) + high_coeffs.offset - high_coeffs.factor * band_data, + low_coeffs.offset - low_coeffs.factor * band_data) diff --git a/satpy/enhancements/abi.py b/satpy/enhancements/abi.py index 871c097042..ca19b4b252 100644 --- a/satpy/enhancements/abi.py +++ b/satpy/enhancements/abi.py @@ -16,14 +16,15 @@ # satpy. If not, see . """Enhancement functions specific to the ABI sensor.""" -from satpy.enhancements import apply_enhancement, using_map_blocks +from satpy.enhancements import exclude_alpha, using_map_blocks def cimss_true_color_contrast(img, **kwargs): """Scale data based on CIMSS True Color recipe for AWIPS.""" - apply_enhancement(img.data, _cimss_true_color_contrast) + _cimss_true_color_contrast(img.data) +@exclude_alpha @using_map_blocks def _cimss_true_color_contrast(img_data): """Perform per-chunk enhancement. diff --git a/satpy/enhancements/ahi.py b/satpy/enhancements/ahi.py index a6896b63c5..a0f332cfa2 100644 --- a/satpy/enhancements/ahi.py +++ b/satpy/enhancements/ahi.py @@ -18,7 +18,7 @@ import dask.array as da import numpy as np -from satpy.enhancements import apply_enhancement, on_dask_array +from satpy.enhancements import exclude_alpha, on_dask_array def jma_true_color_reproduction(img, **kwargs): @@ -31,9 +31,10 @@ def jma_true_color_reproduction(img, **kwargs): Colorado State University—CIRA https://www.jma.go.jp/jma/jma-eng/satellite/introduction/TCR.html """ - apply_enhancement(img.data, _jma_true_color_reproduction) + _jma_true_color_reproduction(img.data) +@exclude_alpha @on_dask_array def _jma_true_color_reproduction(img_data): ccm = np.array([ diff --git a/satpy/enhancements/viirs.py b/satpy/enhancements/viirs.py index e988fb53fd..6a465f161b 100644 --- a/satpy/enhancements/viirs.py +++ b/satpy/enhancements/viirs.py @@ -18,7 +18,7 @@ import numpy as np from trollimage.colormap import Colormap -from satpy.enhancements import apply_enhancement, using_map_blocks +from satpy.enhancements import exclude_alpha, using_map_blocks def water_detection(img, **kwargs): @@ -30,11 +30,12 @@ def water_detection(img, **kwargs): palette = kwargs['palettes'] palette['colors'] = tuple(map(tuple, palette['colors'])) - apply_enhancement(img.data, _water_detection) + _water_detection(img.data) cm = Colormap(*palette['colors']) img.palettize(cm) +@exclude_alpha @using_map_blocks def _water_detection(img_data): data = np.asarray(img_data) diff --git a/satpy/tests/enhancement_tests/test_enhancements.py b/satpy/tests/enhancement_tests/test_enhancements.py index 66dc8100d0..d40c24b411 100644 --- a/satpy/tests/enhancement_tests/test_enhancements.py +++ b/satpy/tests/enhancement_tests/test_enhancements.py @@ -80,14 +80,12 @@ def setup_method(self): @pytest.mark.parametrize("input_data_name", ["ch1", "ch2", "rgb"]) def test_apply_enhancement(self, input_data_name, decorator, exp_call_cls): """Test the 'apply_enhancement' utility function.""" - from satpy.enhancements import apply_enhancement - def _enh_func(img): def _calc_func(data): assert isinstance(data, exp_call_cls) return data decorated_func = decorator(_calc_func) - return apply_enhancement(img.data, decorated_func) + return decorated_func(img.data) in_data = getattr(self, input_data_name) exp_data = in_data.values