diff --git a/pygmt/src/x2sys_cross.py b/pygmt/src/x2sys_cross.py
index eadd20dcfb2..0e988af6de8 100644
--- a/pygmt/src/x2sys_cross.py
+++ b/pygmt/src/x2sys_cross.py
@@ -5,19 +5,19 @@
 import contextlib
 import os
 from pathlib import Path
+from typing import Any, Literal
 
 import pandas as pd
-from packaging.version import Version
 from pygmt.clib import Session
 from pygmt.exceptions import GMTInvalidInput
 from pygmt.helpers import (
-    GMTTempFile,
     build_arg_list,
     data_kind,
     fmt_docstring,
     kwargs_to_strings,
     unique_name,
     use_alias,
+    validate_output_table_type,
 )
 
 
@@ -71,7 +71,12 @@ def tempfile_from_dftrack(track, suffix):
     Z="trackvalues",
 )
 @kwargs_to_strings(R="sequence")
-def x2sys_cross(tracks=None, outfile=None, **kwargs):
+def x2sys_cross(
+    tracks=None,
+    output_type: Literal["pandas", "numpy", "file"] = "pandas",
+    outfile: str | None = None,
+    **kwargs,
+):
     r"""
     Calculate crossovers between track data files.
 
@@ -102,11 +107,8 @@ def x2sys_cross(tracks=None, outfile=None, **kwargs):
         set it will default to $GMT_SHAREDIR/x2sys]. (**Note**: MGD77 files
         will also be looked for via $MGD77_HOME/mgd77_paths.txt and .gmt
         files will be searched for via $GMT_SHAREDIR/mgg/gmtfile_paths).
-
-    outfile : str
-        Optional. The file name for the output ASCII txt file to store the
-        table in.
-
+    {output_type}
+    {outfile}
     tag : str
         Specify the x2sys TAG which identifies the attributes of this data
         type.
@@ -183,68 +185,57 @@ def x2sys_cross(tracks=None, outfile=None, **kwargs):
 
     Returns
     -------
-    crossover_errors : :class:`pandas.DataFrame` or None
-        Table containing crossover error information.
-        Return type depends on whether the ``outfile`` parameter is set:
-
-        - :class:`pandas.DataFrame` with (x, y, ..., etc) if ``outfile`` is not
-          set
-        - None if ``outfile`` is set (track output will be stored in the set in
-          ``outfile``)
+    crossover_errors
+        Table containing crossover error information. Return type depends on ``outfile``
+        and ``output_type``:
+
+        - None if ``outfile`` is set (output will be stored in file set by ``outfile``)
+        - :class:`pandas.DataFrame` or :class:`numpy.ndarray` if ``outfile`` is not set
+          (depends on ``output_type``)
     """
+    output_type = validate_output_table_type(output_type, outfile=outfile)
+
     with Session() as lib:
-        file_contexts = []
+        file_contexts: list[contextlib.AbstractContextManager[Any]] = []
         for track in tracks:
-            kind = data_kind(track)
-            if kind == "file":
-                file_contexts.append(contextlib.nullcontext(track))
-            elif kind == "matrix":
-                # find suffix (-E) of trackfiles used (e.g. xyz, csv, etc) from
-                # $X2SYS_HOME/TAGNAME/TAGNAME.tag file
-                lastline = (
-                    Path(os.environ["X2SYS_HOME"], kwargs["T"], f"{kwargs['T']}.tag")
-                    .read_text(encoding="utf8")
-                    .strip()
-                    .split("\n")[-1]
-                )  # e.g. "-Dxyz -Etsv -I1/1"
-                for item in sorted(lastline.split()):  # sort list alphabetically
-                    if item.startswith(("-E", "-D")):  # prefer -Etsv over -Dxyz
-                        suffix = item[2:]  # e.g. tsv (1st choice) or xyz (2nd choice)
-
-                # Save pandas.DataFrame track data to temporary file
-                file_contexts.append(tempfile_from_dftrack(track=track, suffix=suffix))
-            else:
-                raise GMTInvalidInput(f"Unrecognized data type: {type(track)}")
-
-        with GMTTempFile(suffix=".txt") as tmpfile:
+            match data_kind(track):
+                case "file":
+                    file_contexts.append(contextlib.nullcontext(track))
+                case "matrix":
+                    # Find suffix (-E) of trackfiles used (e.g. xyz, csv, etc) from
+                    # $X2SYS_HOME/TAGNAME/TAGNAME.tag file.
+                    tagfile = Path(
+                        os.environ["X2SYS_HOME"], kwargs["T"], f"{kwargs['T']}.tag"
+                    )
+                    lastline = tagfile.read_text().splitlines()[-1]
+                    # e.g. "-Dxyz -Etsv -I1/1"
+                    for item in sorted(lastline.split()):  # sort list alphabetically
+                        if item.startswith(("-E", "-D")):  # prefer -Etsv over -Dxyz
+                            # e.g. tsv (1st choice) or xyz (2nd choice)
+                            suffix = item[2:]
+
+                    # Save pandas.DataFrame track data to temporary file
+                    file_contexts.append(
+                        tempfile_from_dftrack(track=track, suffix=suffix)
+                    )
+                case _:
+                    raise GMTInvalidInput(f"Unrecognized data type: {type(track)}")
+
+        with lib.virtualfile_out(kind="dataset", fname=outfile) as vouttbl:
             with contextlib.ExitStack() as stack:
                 fnames = [stack.enter_context(c) for c in file_contexts]
-                if outfile is None:
-                    outfile = tmpfile.name
                 lib.call_module(
                     module="x2sys_cross",
-                    args=build_arg_list(kwargs, infile=fnames, outfile=outfile),
-                )
-
-            # Read temporary csv output to a pandas table
-            if outfile == tmpfile.name:  # if outfile isn't set, return pd.DataFrame
-                # Read the tab-separated ASCII table
-                date_format_kwarg = (
-                    {"date_format": "ISO8601"}
-                    if Version(pd.__version__) >= Version("2.0.0")
-                    else {}
+                    args=build_arg_list(kwargs, infile=fnames, outfile=vouttbl),
                 )
-                table = pd.read_csv(
-                    tmpfile.name,
-                    sep="\t",
-                    header=2,  # Column names are on 2nd row
-                    comment=">",  # Skip the 3rd row with a ">"
-                    parse_dates=[2, 3],  # Datetimes on 3rd and 4th column
-                    **date_format_kwarg,  # Parse dates in ISO8601 format on pandas>=2
+                result = lib.virtualfile_to_dataset(
+                    vfname=vouttbl, output_type=output_type, header=2
                 )
-                # Remove the "# " from "# x" in the first column
-                table = table.rename(columns={table.columns[0]: table.columns[0][2:]})
-            elif outfile != tmpfile.name:  # if outfile is set, output in outfile only
-                table = None
-    return table
+                # Convert 3rd and 4th columns to datetimes.
+                # These two columns have names "t_1"/"t_2" or "i_1"/"i_2".
+                # "t_1"/"t_2" means they are datetimes and should be converted.
+                # "i_1"/"i_2" means they are dummy times (i.e., floating-point values).
+                if output_type == "pandas" and result.columns[2] == "t_1":
+                    result.iloc[:, 2:4] = result.iloc[:, 2:4].apply(pd.to_datetime)
+                return result
diff --git a/pygmt/tests/test_x2sys_cross.py b/pygmt/tests/test_x2sys_cross.py
index c9209bd254a..3c2a8509edf 100644
--- a/pygmt/tests/test_x2sys_cross.py
+++ b/pygmt/tests/test_x2sys_cross.py
@@ -49,7 +49,11 @@ def test_x2sys_cross_input_file_output_file():
     x2sys_init(tag=tag, fmtfile="xyz", force=True)
     outfile = tmpdir_p / "tmp_coe.txt"
     output = x2sys_cross(
-        tracks=["@tut_ship.xyz"], tag=tag, coe="i", outfile=outfile
+        tracks=["@tut_ship.xyz"],
+        tag=tag,
+        coe="i",
+        outfile=outfile,
+        output_type="file",
     )
     assert output is None  # check that output is None since outfile is set
 
@@ -97,8 +101,8 @@ def test_x2sys_cross_input_dataframe_output_dataframe(tracks):
     columns = list(output.columns)
     assert columns[:6] == ["x", "y", "i_1", "i_2", "dist_1", "dist_2"]
     assert columns[6:] == ["head_1", "head_2", "vel_1", "vel_2", "z_X", "z_M"]
-    assert output.dtypes["i_1"].type == np.object_
-    assert output.dtypes["i_2"].type == np.object_
+    assert output.dtypes["i_1"].type == np.float64
+    assert output.dtypes["i_2"].type == np.float64
 
 
 @pytest.mark.usefixtures("mock_x2sys_home")
@@ -158,8 +162,8 @@ def test_x2sys_cross_input_dataframe_with_nan(tracks):
     columns = list(output.columns)
     assert columns[:6] == ["x", "y", "i_1", "i_2", "dist_1", "dist_2"]
    assert columns[6:] == ["head_1", "head_2", "vel_1", "vel_2", "z_X", "z_M"]
-    assert output.dtypes["i_1"].type == np.object_
-    assert output.dtypes["i_2"].type == np.object_
+    assert output.dtypes["i_1"].type == np.float64
+    assert output.dtypes["i_2"].type == np.float64
 
 
 @pytest.mark.usefixtures("mock_x2sys_home")
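A minimal usage sketch of the revised output_type/outfile behavior. It mirrors the calls exercised in the tests above; the tag name "TUT_SHIP" and the X2SYS_HOME setup are illustrative assumptions (the test suite derives the tag from a temporary directory via the mock_x2sys_home fixture), while @tut_ship.xyz, coe="i", fmtfile="xyz", and the output_type/outfile combinations come straight from the tests:

    import pandas as pd
    from pygmt import x2sys_cross, x2sys_init

    # Assumes the X2SYS_HOME environment variable points at a writable directory
    # and that the "TUT_SHIP" tag name is free to use (hypothetical name).
    x2sys_init(tag="TUT_SHIP", fmtfile="xyz", force=True)

    # Default output_type="pandas": crossover errors come back as a DataFrame.
    coe = x2sys_cross(tracks=["@tut_ship.xyz"], tag="TUT_SHIP", coe="i")
    assert isinstance(coe, pd.DataFrame)

    # With outfile set (and output_type="file"), results are written to disk
    # and the function returns None.
    result = x2sys_cross(
        tracks=["@tut_ship.xyz"],
        tag="TUT_SHIP",
        coe="i",
        outfile="tmp_coe.txt",
        output_type="file",
    )
    assert result is None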