Skip to content

Commit

Permalink
Initial implementation of virtualfile_from_image
Browse files Browse the repository at this point in the history
Enable passing in 3-band images to GMT via a virtualfile mechanism instead of using a temporary GeoTIFF file which requires rioxarray to be installed. Implementation based on virtualfile_from_grid. Made some adjacent changes around put_matrix to handle ndim==3, though a segfault is happening now.
  • Loading branch information
weiji14 committed Sep 29, 2024
1 parent f97c3a4 commit 331f9aa
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 12 deletions.
5 changes: 3 additions & 2 deletions pygmt/clib/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,10 @@ def dataarray_to_matrix(grid):
>>> print(inc)
[2.0, 2.0]
"""
if len(grid.dims) != 2:
if len(grid.dims) not in {2, 3}:
raise GMTInvalidInput(
f"Invalid number of grid dimensions '{len(grid.dims)}'. Must be 2."
f"Invalid number of grid/image dimensions '{len(grid.dims)}'. "
"Must be 2 for grid, or 3 for image."
)
# Extract region and inc from the grid
region = []
Expand Down
77 changes: 67 additions & 10 deletions pygmt/clib/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,7 @@
from pygmt.clib.loading import get_gmt_version, load_libgmt
from pygmt.datatypes import _GMT_DATASET, _GMT_GRID, _GMT_IMAGE
from pygmt.exceptions import GMTCLibError, GMTCLibNoSessionError, GMTInvalidInput
from pygmt.helpers import (
_validate_data_input,
data_kind,
tempfile_from_geojson,
tempfile_from_image,
)
from pygmt.helpers import _validate_data_input, data_kind, tempfile_from_geojson

FAMILIES = [
"GMT_IS_DATASET", # Entity is a data table
Expand Down Expand Up @@ -868,6 +863,10 @@ def _check_dtype_and_dim(self, array, ndim):
... gmttype = ses._check_dtype_and_dim(data, ndim=2)
... gmttype == ses["GMT_FLOAT"]
True
>>> data = np.ones((5, 3, 2), dtype="uint8")
>>> with Session() as ses:
... gmttype = ses._check_dtype_and_dim(data, ndim=3)
... gmttype == ses["GMT_UCHAR"]
"""
# Check that the array has the given number of dimensions
if array.ndim != ndim:
Expand Down Expand Up @@ -1006,9 +1005,9 @@ def put_strings(self, dataset, family, strings):
f"Failed to put strings of type {strings.dtype} into dataset"
)

def put_matrix(self, dataset, matrix, pad=0):
def put_matrix(self, dataset, matrix, pad=0, ndim=2):
"""
Attach a numpy 2-D array to a GMT dataset.
Attach a numpy n-D (2-D or 3-D) array to a GMT dataset.
Use this function to attach numpy array data to a GMT dataset and pass
it to GMT modules. Wraps ``GMT_Put_Matrix``.
Expand Down Expand Up @@ -1048,7 +1047,7 @@ def put_matrix(self, dataset, matrix, pad=0):
restype=ctp.c_int,
)

gmt_type = self._check_dtype_and_dim(matrix, ndim=2)
gmt_type = self._check_dtype_and_dim(matrix, ndim=ndim)
matrix_pointer = matrix.ctypes.data_as(ctp.c_void_p)
status = c_put_matrix(
self.session_pointer, dataset, gmt_type, pad, matrix_pointer
Expand Down Expand Up @@ -1610,6 +1609,64 @@ def virtualfile_from_grid(self, grid):
with self.open_virtualfile(*args) as vfile:
yield vfile

@contextlib.contextmanager
def virtualfile_from_image(self, image: xr.DataArray):
"""
Store a image in a virtual file.
Use the virtual file name to pass in the data in your image to a GMT module.
Images must be :class:`xarray.DataArray` instances.
Context manager (use in a ``with`` block). Yields the virtual file name that you
can pass as an argument to a GMT module call. Closes the virtual file upon exit
of the ``with`` block.
The virtual file will contain the image as a ``GMT_MATRIX`` with extra metadata.
Use this instead of creating a data container and virtual file by hand with
:meth:`pygmt.clib.Session.create_data`, :meth:`pygmt.clib.Session.put_matrix`,
and :meth:`pygmt.clib.Session.open_virtualfile`.
The image data matrix must be C contiguous in memory. If it is not (e.g., it is
a slice of a larger array), the array will be copied to make sure it is.
Parameters
----------
image : :class:`xarray.DataArray`
The image that will be included in the virtual file.
Yields
------
fname : str
The name of virtual file. Pass this as a file name argument to a GMT module.
"""
_gtype = {0: "GMT_GRID_IS_CARTESIAN", 1: "GMT_GRID_IS_GEO"}[image.gmt.gtype]
_reg = {0: "GMT_GRID_NODE_REG", 1: "GMT_GRID_PIXEL_REG"}[image.gmt.registration]

Check warning on line 1645 in pygmt/clib/session.py

View check run for this annotation

Codecov / codecov/patch

pygmt/clib/session.py#L1644-L1645

Added lines #L1644 - L1645 were not covered by tests

# Conversion to a C-contiguous array needs to be done here and not in put_matrix
# because we need to maintain a reference to the copy while it is being used by
# the C API. Otherwise, the array would be garbage collected and the memory
# freed. Creating it in this context manager guarantees that the copy will be
# around until the virtual file is closed. The conversion is implicit in
# dataarray_to_matrix.
matrix, region, inc = dataarray_to_matrix(image)

Check warning on line 1653 in pygmt/clib/session.py

View check run for this annotation

Codecov / codecov/patch

pygmt/clib/session.py#L1653

Added line #L1653 was not covered by tests

family = "GMT_IS_IMAGE|GMT_VIA_MATRIX"
geometry = "GMT_IS_SURFACE"
gmt_image = self.create_data(

Check warning on line 1657 in pygmt/clib/session.py

View check run for this annotation

Codecov / codecov/patch

pygmt/clib/session.py#L1655-L1657

Added lines #L1655 - L1657 were not covered by tests
family,
geometry,
mode=f"GMT_CONTAINER_ONLY|{_gtype}",
ranges=region[0:4], # (xmin, xmax, ymin, ymax) only, leave out (zmin, zmax)
inc=inc[0:2], # (x-inc, y-inc) only, leave out z-inc
registration=_reg,
)
self.put_matrix(gmt_image, matrix, ndim=3)
args = (family, geometry, "GMT_IN|GMT_IS_REFERENCE", gmt_image)
with self.open_virtualfile(*args) as vfile:
yield vfile

Check warning on line 1668 in pygmt/clib/session.py

View check run for this annotation

Codecov / codecov/patch

pygmt/clib/session.py#L1665-L1668

Added lines #L1665 - L1668 were not covered by tests

@contextlib.contextmanager
def virtualfile_from_stringio(self, stringio: io.StringIO):
r"""
Expand Down Expand Up @@ -1796,7 +1853,7 @@ def virtualfile_in( # noqa: PLR0912
"arg": contextlib.nullcontext,
"geojson": tempfile_from_geojson,
"grid": self.virtualfile_from_grid,
"image": tempfile_from_image,
"image": self.virtualfile_from_image,
"stringio": self.virtualfile_from_stringio,
# Note: virtualfile_from_matrix is not used because a matrix can be
# converted to vectors instead, and using vectors allows for better
Expand Down

0 comments on commit 331f9aa

Please sign in to comment.