Skip to content

Commit

Permalink
pygmt.x2sys_cross: Refactor to use virtualfiles for output tables
Browse files Browse the repository at this point in the history
Co-authored-by: Wei Ji <[email protected]>
  • Loading branch information
seisman and weiji14 authored Jun 9, 2024
1 parent 88eddc7 commit 844594f
Show file tree
Hide file tree
Showing 2 changed files with 139 additions and 80 deletions.
111 changes: 58 additions & 53 deletions pygmt/src/x2sys_cross.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,12 @@
import contextlib
import os
from pathlib import Path
from typing import Any

import pandas as pd
from packaging.version import Version
from pygmt.clib import Session
from pygmt.exceptions import GMTInvalidInput
from pygmt.helpers import (
GMTTempFile,
build_arg_list,
data_kind,
fmt_docstring,
Expand Down Expand Up @@ -71,7 +70,9 @@ def tempfile_from_dftrack(track, suffix):
Z="trackvalues",
)
@kwargs_to_strings(R="sequence")
def x2sys_cross(tracks=None, outfile=None, **kwargs):
def x2sys_cross(
tracks=None, outfile: str | None = None, **kwargs
) -> pd.DataFrame | None:
r"""
Calculate crossovers between track data files.
Expand Down Expand Up @@ -103,10 +104,8 @@ def x2sys_cross(tracks=None, outfile=None, **kwargs):
will also be looked for via $MGD77_HOME/mgd77_paths.txt and .gmt
files will be searched for via $GMT_SHAREDIR/mgg/gmtfile_paths).
outfile : str
Optional. The file name for the output ASCII txt file to store the
table in.
outfile
The file name for the output ASCII txt file to store the table in.
tag : str
Specify the x2sys TAG which identifies the attributes of this data
type.
Expand Down Expand Up @@ -183,68 +182,74 @@ def x2sys_cross(tracks=None, outfile=None, **kwargs):
Returns
-------
crossover_errors : :class:`pandas.DataFrame` or None
Table containing crossover error information.
Return type depends on whether the ``outfile`` parameter is set:
- :class:`pandas.DataFrame` with (x, y, ..., etc) if ``outfile`` is not
set
- None if ``outfile`` is set (track output will be stored in the set in
``outfile``)
crossover_errors
Table containing crossover error information. A :class:`pandas.DataFrame` object
is returned if ``outfile`` is not set; otherwise, ``None`` is returned and the
output is stored in the file set by ``outfile``.
"""
with Session() as lib:
file_contexts = []
for track in tracks:
kind = data_kind(track)
if kind == "file":
# Determine output type based on 'outfile' parameter
output_type = "pandas" if outfile is None else "file"

file_contexts: list[contextlib.AbstractContextManager[Any]] = []
for track in tracks:
match data_kind(track):
case "file":
file_contexts.append(contextlib.nullcontext(track))
elif kind == "matrix":
case "matrix":
# find suffix (-E) of trackfiles used (e.g. xyz, csv, etc) from
# $X2SYS_HOME/TAGNAME/TAGNAME.tag file
lastline = (
Path(os.environ["X2SYS_HOME"], kwargs["T"], f"{kwargs['T']}.tag")
.read_text(encoding="utf8")
.strip()
.split("\n")[-1]
) # e.g. "-Dxyz -Etsv -I1/1"
tagfile = Path(
os.environ["X2SYS_HOME"], kwargs["T"], f"{kwargs['T']}.tag"
)
# Last line is like "-Dxyz -Etsv -I1/1"
lastline = tagfile.read_text(encoding="utf8").splitlines()[-1]
for item in sorted(lastline.split()): # sort list alphabetically
if item.startswith(("-E", "-D")): # prefer -Etsv over -Dxyz
suffix = item[2:] # e.g. tsv (1st choice) or xyz (2nd choice)

# Save pandas.DataFrame track data to temporary file
file_contexts.append(tempfile_from_dftrack(track=track, suffix=suffix))
else:
case _:
raise GMTInvalidInput(f"Unrecognized data type: {type(track)}")

with GMTTempFile(suffix=".txt") as tmpfile:
with Session() as lib:
with lib.virtualfile_out(kind="dataset", fname=outfile) as vouttbl:
with contextlib.ExitStack() as stack:
fnames = [stack.enter_context(c) for c in file_contexts]
if outfile is None:
outfile = tmpfile.name
lib.call_module(
module="x2sys_cross",
args=build_arg_list(kwargs, infile=fnames, outfile=outfile),
)

# Read temporary csv output to a pandas table
if outfile == tmpfile.name: # if outfile isn't set, return pd.DataFrame
# Read the tab-separated ASCII table
date_format_kwarg = (
{"date_format": "ISO8601"}
if Version(pd.__version__) >= Version("2.0.0")
else {}
args=build_arg_list(kwargs, infile=fnames, outfile=vouttbl),
)
table = pd.read_csv(
tmpfile.name,
sep="\t",
header=2, # Column names are on 2nd row
comment=">", # Skip the 3rd row with a ">"
parse_dates=[2, 3], # Datetimes on 3rd and 4th column
**date_format_kwarg, # Parse dates in ISO8601 format on pandas>=2
result = lib.virtualfile_to_dataset(
vfname=vouttbl, output_type=output_type, header=2
)
# Remove the "# " from "# x" in the first column
table = table.rename(columns={table.columns[0]: table.columns[0][2:]})
elif outfile != tmpfile.name: # if outfile is set, output in outfile only
table = None

return table
if output_type == "file":
return result

# Convert the 3rd and 4th columns to datetime/timedelta for pandas output.
# These two columns are named "t_1"/"t_2" or "i_1"/"i_2":
# "t_" means absolute datetimes and "i_" means dummy times.
# Internally, GMT represents both as double-precision numbers relative to
# TIME_EPOCH, in the unit defined by TIME_UNIT.
# In GMT, TIME_UNIT can be 'y' (year), 'o' (month), 'w' (week), 'd' (day),
# 'h' (hour), 'm' (minute), or 's' (second). A year is 365.2425 days and all
# months are of equal length (1/12 of a year).
# pd.to_timedelta() supports the units 'W'/'D'/'h'/'m'/'s'/'ms'/'us'/'ns' but not
# years or months, so values in those two units are scaled to seconds first.
match time_unit := lib.get_default("TIME_UNIT"):
case "y":
unit = "s"
scale = 365.2425 * 86400.0
case "o":
unit = "s"
scale = 365.2425 / 12.0 * 86400.0
case "w" | "d" | "h" | "m" | "s":
unit = time_unit.upper() if time_unit in "wd" else time_unit
scale = 1.0

columns = result.columns[2:4]
result[columns] *= scale
result[columns] = result[columns].apply(pd.to_timedelta, unit=unit)
if columns[0][0] == "t": # "t" or "i":
result[columns] += pd.Timestamp(lib.get_default("TIME_EPOCH"))
return result
108 changes: 81 additions & 27 deletions pygmt/tests/test_x2sys_cross.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import pandas as pd
import pytest
from packaging.version import Version
from pygmt import x2sys_cross, x2sys_init
from pygmt import config, x2sys_cross, x2sys_init
from pygmt.clib import __gmt_version__
from pygmt.datasets import load_sample_data
from pygmt.exceptions import GMTInvalidInput
Expand Down Expand Up @@ -52,15 +52,20 @@ def test_x2sys_cross_input_file_output_file():
output = x2sys_cross(
tracks=["@tut_ship.xyz"], tag=tag, coe="i", outfile=outfile
)

assert output is None # check that output is None since outfile is set
assert outfile.stat().st_size > 0 # check that outfile exists at path
_ = pd.read_csv(outfile, sep="\t", header=2) # ensure ASCII text file loads ok
result = pd.read_csv(outfile, sep="\t", comment=">", header=2)
assert result.shape == (14374, 12) if sys.platform == "darwin" else (14338, 12)
columns = list(result.columns)
assert columns[:6] == ["# x", "y", "i_1", "i_2", "dist_1", "dist_2"]
assert columns[6:] == ["head_1", "head_2", "vel_1", "vel_2", "z_X", "z_M"]
npt.assert_allclose(result["i_1"].min(), 45.2099, rtol=1.0e-4)
npt.assert_allclose(result["i_1"].max(), 82945.9370, rtol=1.0e-4)


@pytest.mark.usefixtures("mock_x2sys_home")
@pytest.mark.xfail(
condition=Version(__gmt_version__) < Version("6.5.0") or sys.platform == "darwin",
condition=Version(__gmt_version__) < Version("6.5.0"),
reason="Upstream bug fixed in https://github.com/GenericMappingTools/gmt/pull/8188",
)
def test_x2sys_cross_input_file_output_dataframe():
Expand All @@ -74,39 +79,70 @@ def test_x2sys_cross_input_file_output_dataframe():
output = x2sys_cross(tracks=["@tut_ship.xyz"], tag=tag, coe="i")

assert isinstance(output, pd.DataFrame)
assert output.shape == (14338, 12)
assert output.shape == (14374, 12) if sys.platform == "darwin" else (14338, 12)
columns = list(output.columns)
assert columns[:6] == ["x", "y", "i_1", "i_2", "dist_1", "dist_2"]
assert columns[6:] == ["head_1", "head_2", "vel_1", "vel_2", "z_X", "z_M"]
assert output["i_1"].dtype.type == np.timedelta64
assert output["i_2"].dtype.type == np.timedelta64
npt.assert_allclose(output["i_1"].min().total_seconds(), 45.2099, rtol=1.0e-4)
npt.assert_allclose(output["i_1"].max().total_seconds(), 82945.937, rtol=1.0e-4)


@pytest.mark.benchmark
@pytest.mark.usefixtures("mock_x2sys_home")
def test_x2sys_cross_input_dataframe_output_dataframe(tracks):
@pytest.mark.parametrize("unit", ["s", "o", "y"])
def test_x2sys_cross_input_dataframe_output_dataframe(tracks, unit):
"""
Run x2sys_cross by passing in one dataframe, and output internal crossovers to a
pandas.DataFrame.
pandas.DataFrame, checking TIME_UNIT s (second), o (month), and y (year).
"""
with TemporaryDirectory(prefix="X2SYS", dir=Path.cwd()) as tmpdir:
tag = Path(tmpdir).name
x2sys_init(tag=tag, fmtfile="xyz", force=True)

output = x2sys_cross(tracks=tracks, tag=tag, coe="i")
with config(TIME_UNIT=unit):
output = x2sys_cross(tracks=tracks, tag=tag, coe="i")

assert isinstance(output, pd.DataFrame)
assert output.shape == (14, 12)
columns = list(output.columns)
assert columns[:6] == ["x", "y", "i_1", "i_2", "dist_1", "dist_2"]
assert columns[6:] == ["head_1", "head_2", "vel_1", "vel_2", "z_X", "z_M"]
assert output.dtypes["i_1"].type == np.object_
assert output.dtypes["i_2"].type == np.object_
assert output["i_1"].dtype.type == np.timedelta64
assert output["i_2"].dtype.type == np.timedelta64

# Scale factor to convert a value expressed in TIME_UNIT to seconds
match unit:
case "y":
scale = 365.2425 * 86400.0
case "o":
scale = 365.2425 / 12.0 * 86400.0
case _:
scale = 1.0
npt.assert_allclose(
output["i_1"].min().total_seconds(), 0.9175 * scale, rtol=1.0e-4
)
npt.assert_allclose(
output["i_1"].max().total_seconds(), 23.9996 * scale, rtol=1.0e-4
)


@pytest.mark.usefixtures("mock_x2sys_home")
def test_x2sys_cross_input_two_dataframes():
@pytest.mark.parametrize(
("unit", "epoch"),
[
("s", "1970-01-01T00:00:00"),
("o", "1970-01-01T00:00:00"),
("y", "1970-01-01T00:00:00"),
("s", "2012-03-04T05:06:07"),
],
)
def test_x2sys_cross_input_two_dataframes(unit, epoch):
"""
Run x2sys_cross by passing in two pandas.DataFrame tables with a time column, and
output external crossovers to a pandas.DataFrame.
output external crossovers to a pandas.DataFrame, checking TIME_UNIT s (second),
o (month), and y (year), and TIME_EPOCH 1970 and 2012.
"""
with TemporaryDirectory(prefix="X2SYS", dir=Path.cwd()) as tmpdir:
tmpdir_p = Path(tmpdir)
Expand All @@ -127,15 +163,22 @@ def test_x2sys_cross_input_two_dataframes():
track["time"] = pd.date_range(start=f"2020-{i}1-01", periods=10, freq="min")
tracks.append(track)

output = x2sys_cross(tracks=tracks, tag=tag, coe="e")
with config(TIME_UNIT=unit, TIME_EPOCH=epoch):
output = x2sys_cross(tracks=tracks, tag=tag, coe="e")

assert isinstance(output, pd.DataFrame)
assert output.shape == (26, 12)
columns = list(output.columns)
assert columns[:6] == ["x", "y", "t_1", "t_2", "dist_1", "dist_2"]
assert columns[6:] == ["head_1", "head_2", "vel_1", "vel_2", "z_X", "z_M"]
assert output.dtypes["t_1"].type == np.datetime64
assert output.dtypes["t_2"].type == np.datetime64
assert output["t_1"].dtype.type == np.datetime64
assert output["t_2"].dtype.type == np.datetime64

tolerance = pd.Timedelta("1ms")
t1_min = pd.Timestamp("2020-01-01 00:00:10.6677")
t1_max = pd.Timestamp("2020-01-01 00:08:29.8067")
assert abs(output["t_1"].min() - t1_min) < tolerance
assert abs(output["t_1"].max() - t1_max) < tolerance


@pytest.mark.usefixtures("mock_x2sys_home")
Expand All @@ -159,8 +202,8 @@ def test_x2sys_cross_input_dataframe_with_nan(tracks):
columns = list(output.columns)
assert columns[:6] == ["x", "y", "i_1", "i_2", "dist_1", "dist_2"]
assert columns[6:] == ["head_1", "head_2", "vel_1", "vel_2", "z_X", "z_M"]
assert output.dtypes["i_1"].type == np.object_
assert output.dtypes["i_2"].type == np.object_
assert output.dtypes["i_1"].type == np.timedelta64
assert output.dtypes["i_2"].type == np.timedelta64


@pytest.mark.usefixtures("mock_x2sys_home")
Expand Down Expand Up @@ -201,7 +244,7 @@ def test_x2sys_cross_invalid_tracks_input_type(tracks):

@pytest.mark.usefixtures("mock_x2sys_home")
@pytest.mark.xfail(
condition=Version(__gmt_version__) < Version("6.5.0") or sys.platform == "darwin",
condition=Version(__gmt_version__) < Version("6.5.0"),
reason="Upstream bug fixed in https://github.com/GenericMappingTools/gmt/pull/8188",
)
def test_x2sys_cross_region_interpolation_numpoints():
Expand All @@ -222,15 +265,21 @@ def test_x2sys_cross_region_interpolation_numpoints():
)

assert isinstance(output, pd.DataFrame)
assert output.shape == (3882, 12)
# Check crossover errors (z_X) and mean value of observables (z_M)
npt.assert_allclose(output.z_X.mean(), -138.66, rtol=1e-4)
npt.assert_allclose(output.z_M.mean(), -2896.875915)
if sys.platform == "darwin":
assert output.shape == (3894, 12)
# Check crossover errors (z_X) and mean value of observables (z_M)
npt.assert_allclose(output.z_X.mean(), -138.23215, rtol=1e-4)
npt.assert_allclose(output.z_M.mean(), -2897.187545, rtol=1e-4)
else:
assert output.shape == (3882, 12)
# Check crossover errors (z_X) and mean value of observables (z_M)
npt.assert_allclose(output.z_X.mean(), -138.66, rtol=1e-4)
npt.assert_allclose(output.z_M.mean(), -2896.875915, rtol=1e-4)


@pytest.mark.usefixtures("mock_x2sys_home")
@pytest.mark.xfail(
condition=Version(__gmt_version__) < Version("6.5.0") or sys.platform == "darwin",
condition=Version(__gmt_version__) < Version("6.5.0"),
reason="Upstream bug fixed in https://github.com/GenericMappingTools/gmt/pull/8188",
)
def test_x2sys_cross_trackvalues():
Expand All @@ -243,7 +292,12 @@ def test_x2sys_cross_trackvalues():
output = x2sys_cross(tracks=["@tut_ship.xyz"], tag=tag, trackvalues=True)

assert isinstance(output, pd.DataFrame)
assert output.shape == (14338, 12)
# Check mean of track 1 values (z_1) and track 2 values (z_2)
npt.assert_allclose(output.z_1.mean(), -2422.418556, rtol=1e-4)
npt.assert_allclose(output.z_2.mean(), -2402.268364, rtol=1e-4)
if sys.platform == "darwin":
assert output.shape == (14374, 12)
# Check mean of track 1 values (z_1) and track 2 values (z_2)
npt.assert_allclose(output.z_1.mean(), -2422.973372, rtol=1e-4)
npt.assert_allclose(output.z_2.mean(), -2402.87476, rtol=1e-4)
else:
assert output.shape == (14338, 12)
npt.assert_allclose(output.z_1.mean(), -2422.418556, rtol=1e-4)
npt.assert_allclose(output.z_2.mean(), -2402.268364, rtol=1e-4)

0 comments on commit 844594f

Please sign in to comment.