Skip to content

Commit

Permalink
Allow pandas.Series inputs to fig.histogram and pygmt.info (#1329)
Browse files Browse the repository at this point in the history
Let 1D pandas.Series inputs work properly by modifying
the virtualfile_from_data function. Also added two tests
in test_histogram.py and test_info.py to ensure this works.
  • Loading branch information
weiji14 authored Jun 20, 2021
1 parent 7f37e1c commit d9df659
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 7 deletions.
6 changes: 4 additions & 2 deletions pygmt/clib/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -1449,9 +1449,11 @@ def virtualfile_from_data(
_data.extend(extra_arrays)
elif kind == "matrix": # turn 2D arrays into list of vectors
try:
# pandas.Series will be handled below like a 1d numpy ndarray
assert not hasattr(data, "to_frame")
# pandas.DataFrame and xarray.Dataset types
_data = [array for _, array in data.items()]
except AttributeError:
except (AttributeError, AssertionError):
try:
# Just use virtualfile_from_matrix for 2D numpy.ndarray
# which are signed integer (i), unsigned integer (u) or
Expand All @@ -1460,7 +1462,7 @@ def virtualfile_from_data(
_virtualfile_from = self.virtualfile_from_matrix
_data = (data,)
except (AssertionError, AttributeError):
# Python lists, tuples, and numpy ndarray types
# Python list, tuple, numpy ndarray and pandas.Series types
_data = np.atleast_2d(np.asanyarray(data).T)

# Finally create the virtualfile from the data, to be passed into GMT
Expand Down
12 changes: 7 additions & 5 deletions pygmt/tests/test_histogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,24 @@
"""
Tests histogram.
"""
import pandas as pd
import pytest
from pygmt import Figure


@pytest.fixture(scope="module")
def table():
@pytest.fixture(scope="module", name="table", params=[list, pd.Series])
def fixture_table(request):
"""
Returns a list of integers to be used in the histogram.
"""
return [1, 1, 1, 1, 1, 1, 2, 2, 2, 3, 4, 5, 6, 7, 8, 8, 8, 8, 8, 8]
data = [1, 1, 1, 1, 1, 1, 2, 2, 2, 3, 4, 5, 6, 7, 8, 8, 8, 8, 8, 8]
return request.param(data)


@pytest.mark.mpl_image_compare
@pytest.mark.mpl_image_compare(filename="test_histogram.png")
def test_histogram(table):
"""
Tests plotting a histogram using a list of integers.
Tests plotting a histogram using a sequence of integers from a table.
"""
fig = Figure()
fig.histogram(
Expand Down
9 changes: 9 additions & 0 deletions pygmt/tests/test_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,15 @@ def test_info_2d_list():
assert output == expected_output


def test_info_series():
"""
Make sure info works on a pandas.Series input.
"""
output = info(pd.Series(data=[0, 4, 2, 8, 6]))
expected_output = "<vector memory>: N = 5 <0/8>\n"
assert output == expected_output


def test_info_dataframe():
"""
Make sure info works on pandas.DataFrame inputs.
Expand Down

0 comments on commit d9df659

Please sign in to comment.