Skip to content

Commit

Permalink
Use arrow utilities from tfx_bsl
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 273571365
  • Loading branch information
paulgc authored and tf-data-validation-team committed Oct 8, 2019
1 parent b9f060a commit bf40237
Show file tree
Hide file tree
Showing 21 changed files with 41 additions and 1,500 deletions.
17 changes: 4 additions & 13 deletions tensorflow_data_validation/arrow/arrow_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,19 +21,9 @@
import numpy as np
from tensorflow_data_validation import types
from tensorflow_data_validation.pyarrow_tf import pyarrow as pa
from tensorflow_data_validation.pywrap import pywrap_tensorflow_data_validation as pywrap
from tfx_bsl.arrow import array_util
from typing import Iterable, Optional, Text, Tuple

# The following are function aliases, so they are valid function names.
# pylint: disable=invalid-name
ListLengthsFromListArray = pywrap.TFDV_Arrow_ListLengthsFromListArray
GetFlattenedArrayParentIndices = pywrap.TFDV_Arrow_GetFlattenedArrayParentIndices
GetArrayNullBitmapAsByteArray = pywrap.TFDV_Arrow_GetArrayNullBitmapAsByteArray
GetBinaryArrayTotalByteSize = pywrap.TFDV_Arrow_GetBinaryArrayTotalByteSize
ValueCounts = pywrap.TFDV_Arrow_ValueCounts
MakeListArrayFromParentIndicesAndValues = (
pywrap.TFDV_Arrow_MakeListArrayFromParentIndicesAndValues)


def _get_weight_feature(input_table: pa.Table,
weight_feature: Text) -> np.ndarray:
Expand All @@ -58,7 +48,7 @@ def _get_weight_feature(input_table: pa.Table,
'table.'.format(weight_feature))

# Before flattening, check that there is a single value for each example.
weight_lengths = ListLengthsFromListArray(weights).to_numpy()
weight_lengths = array_util.ListLengthsFromListArray(weights).to_numpy()
if not np.all(weight_lengths == 1):
raise ValueError(
'Weight feature "{}" must have exactly one value in each example.'
Expand Down Expand Up @@ -148,7 +138,8 @@ def _recursion_helper(
flat_struct_array = array.flatten()
flat_weights = None
if weights is not None:
flat_weights = weights[GetFlattenedArrayParentIndices(array).to_numpy()]
flat_weights = weights[
array_util.GetFlattenedArrayParentIndices(array).to_numpy()]
for field in flat_struct_array.type:
field_name = field.name
# use "yield from" after PY 3.3.
Expand Down
207 changes: 0 additions & 207 deletions tensorflow_data_validation/arrow/arrow_util_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,220 +20,13 @@
import itertools

from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
import six
from tensorflow_data_validation import types
from tensorflow_data_validation.arrow import arrow_util
from tensorflow_data_validation.pyarrow_tf import pyarrow as pa


class ArrowUtilTest(absltest.TestCase):
  """Tests for the Arrow array helper functions exposed by `arrow_util`."""

  def test_invalid_input_type(self):
    """Inputs of the wrong type must raise RuntimeError with a clear message."""
    functions_expecting_list_array = [
        arrow_util.ListLengthsFromListArray,
        arrow_util.GetFlattenedArrayParentIndices,
    ]
    functions_expecting_array = [arrow_util.GetArrayNullBitmapAsByteArray]
    functions_expecting_binary_array = [arrow_util.GetBinaryArrayTotalByteSize]

    # A plain Python int is not an Arrow array, so every function rejects it.
    for f in itertools.chain(functions_expecting_list_array,
                             functions_expecting_array,
                             functions_expecting_binary_array):
      with self.assertRaisesRegex(RuntimeError, "Could not unwrap Array"):
        f(1)

    # A flat (non-list) array is rejected by the list-array functions.
    for f in functions_expecting_list_array:
      with self.assertRaisesRegex(RuntimeError, "Expected ListArray but got"):
        f(pa.array([1, 2, 3]))

    # A list array is rejected by the binary-array functions.
    for f in functions_expecting_binary_array:
      with self.assertRaisesRegex(RuntimeError, "Expected BinaryArray"):
        f(pa.array([[1, 2, 3]]))

  def test_list_lengths(self):
    """Covers an empty array, an empty list element and a null element."""
    list_lengths = arrow_util.ListLengthsFromListArray(
        pa.array([], type=pa.list_(pa.int64())))
    self.assertTrue(list_lengths.equals(pa.array([], type=pa.int32())))

    list_lengths = arrow_util.ListLengthsFromListArray(
        pa.array([[1., 2.], [], [3.]]))
    self.assertTrue(list_lengths.equals(pa.array([2, 0, 1], type=pa.int32())))

    # A null list is expected to report length 0, same as an empty list.
    list_lengths = arrow_util.ListLengthsFromListArray(
        pa.array([[1., 2.], None, [3.]]))
    self.assertTrue(list_lengths.equals(pa.array([2, 0, 1], type=pa.int32())))

  def test_get_array_null_bitmap_as_byte_array(self):
    """Expects one uint8 per element: 1 where the element is null, else 0."""
    array = pa.array([], type=pa.int32())
    null_masks = arrow_util.GetArrayNullBitmapAsByteArray(array)
    self.assertTrue(null_masks.equals(pa.array([], type=pa.uint8())))

    array = pa.array([1, 2, None, 3, None], type=pa.int32())
    null_masks = arrow_util.GetArrayNullBitmapAsByteArray(array)
    self.assertTrue(
        null_masks.equals(pa.array([0, 0, 1, 0, 1], type=pa.uint8())))

    array = pa.array([1, 2, 3])
    null_masks = arrow_util.GetArrayNullBitmapAsByteArray(array)
    self.assertTrue(null_masks.equals(pa.array([0, 0, 0], type=pa.uint8())))

    array = pa.array([None, None, None], type=pa.int32())
    null_masks = arrow_util.GetArrayNullBitmapAsByteArray(array)
    self.assertTrue(null_masks.equals(pa.array([1, 1, 1], type=pa.uint8())))
    # Demonstrate that the returned array can be converted to a numpy boolean
    # array w/o copying.
    # `np.bool_` is used instead of the `np.bool` alias, which was deprecated
    # in NumPy 1.20 and removed in NumPy 1.24.
    np.testing.assert_equal(
        np.array([True, True, True]), null_masks.to_numpy().view(np.bool_))

  def test_get_flattened_array_parent_indices(self):
    """Each flattened value must map back to the index of its parent list."""
    indices = arrow_util.GetFlattenedArrayParentIndices(
        pa.array([], type=pa.list_(pa.int32())))
    self.assertTrue(indices.equals(pa.array([], type=pa.int32())))

    # The empty list at position 2 contributes no values, hence no index 2.
    indices = arrow_util.GetFlattenedArrayParentIndices(
        pa.array([[1.], [2.], [], [3.]]))
    self.assertTrue(indices.equals(pa.array([0, 1, 3], type=pa.int32())))

  def test_get_binary_array_total_byte_size(self):
    """Covers nulls, empty values, slices, unicode input and an empty array."""
    binary_array = pa.array([b"abc", None, b"def", b"", b"ghi"])
    self.assertEqual(9, arrow_util.GetBinaryArrayTotalByteSize(binary_array))
    # Slices must only count the bytes of the elements they cover.
    sliced_1_2 = binary_array.slice(1, 2)
    self.assertEqual(3, arrow_util.GetBinaryArrayTotalByteSize(sliced_1_2))
    sliced_2 = binary_array.slice(2)
    self.assertEqual(6, arrow_util.GetBinaryArrayTotalByteSize(sliced_2))

    unicode_array = pa.array([u"abc"])
    self.assertEqual(3, arrow_util.GetBinaryArrayTotalByteSize(unicode_array))

    empty_array = pa.array([], type=pa.binary())
    self.assertEqual(0, arrow_util.GetBinaryArrayTotalByteSize(empty_array))

  def _value_counts_struct_array_to_dict(self, value_counts):
    """Converts a value-counts struct array into a `{value: count}` dict.

    Args:
      value_counts: an iterable of struct scalars, each of which yields a dict
        with "values" and "counts" keys from `as_py()`.

    Returns:
      A dict mapping each "values" entry to its "counts" entry.
    """
    result = {}
    for value_count in value_counts:
      value_count = value_count.as_py()
      result[value_count["values"]] = value_count["counts"]
    return result

  def test_value_counts_binary(self):
    """Counts occurrences of each distinct binary value."""
    binary_array = pa.array([b"abc", b"ghi", b"def", b"ghi", b"ghi", b"def"])
    expected_result = {b"abc": 1, b"ghi": 3, b"def": 2}
    self.assertDictEqual(self._value_counts_struct_array_to_dict(
        arrow_util.ValueCounts(binary_array)), expected_result)

  def test_value_counts_integer(self):
    """Counts occurrences of each distinct integer value."""
    int_array = pa.array([1, 4, 1, 3, 1, 4])
    expected_result = {1: 3, 4: 2, 3: 1}
    self.assertDictEqual(self._value_counts_struct_array_to_dict(
        arrow_util.ValueCounts(int_array)), expected_result)

  def test_value_counts_empty(self):
    """An empty array must produce no counts."""
    empty_array = pa.array([])
    expected_result = {}
    self.assertDictEqual(self._value_counts_struct_array_to_dict(
        arrow_util.ValueCounts(empty_array)), expected_result)

# Named-parameter cases for inputs that MakeListArrayFromParentIndicesAndValues
# is expected to reject. Each entry gives the arguments together with the
# exception type and the message pattern the call must raise.
_MAKE_LIST_ARRAY_INVALID_INPUT_TEST_CASES = [
    # `num_parents` is not an integer.
    {
        "testcase_name": "invalid_parent_index",
        "num_parents": None,
        "parent_indices": np.array([0], dtype=np.int64),
        "values": pa.array([1]),
        "expected_error": RuntimeError,
        "expected_error_regexp": "Expected integer",
    },
    # `parent_indices` is a Python list rather than a numpy array.
    {
        "testcase_name": "parent_indices_not_np",
        "num_parents": 1,
        "parent_indices": [0],
        "values": pa.array([1]),
        "expected_error": TypeError,
        "expected_error_regexp": "to be a numpy array",
    },
    # `parent_indices` is two-dimensional.
    {
        "testcase_name": "parent_indices_not_1d",
        "num_parents": 1,
        "parent_indices": np.array([[0]], dtype=np.int64),
        "values": pa.array([1]),
        "expected_error": TypeError,
        "expected_error_regexp": "to be a 1-D int64 numpy array",
    },
    # `parent_indices` has dtype int32 instead of int64.
    {
        "testcase_name": "parent_indices_not_int64",
        "num_parents": 1,
        "parent_indices": np.array([0], dtype=np.int32),
        "values": pa.array([1]),
        "expected_error": TypeError,
        "expected_error_regexp": "to be a 1-D int64 numpy array",
    },
    # One parent index but two values.
    {
        "testcase_name": "parent_indices_length_not_equal_to_values_length",
        "num_parents": 1,
        "parent_indices": np.array([0], dtype=np.int64),
        "values": pa.array([1, 2]),
        "expected_error": RuntimeError,
        "expected_error_regexp": "values array and parent indices array must be of the same length",
    },
    # Parent index 1 is out of range when there is only one parent.
    {
        "testcase_name": "num_parents_too_small",
        "num_parents": 1,
        "parent_indices": np.array([1], dtype=np.int64),
        "values": pa.array([1]),
        "expected_error": RuntimeError,
        "expected_error_regexp": "Found a parent index 1 while num_parents was 1",
    },
]


# Named-parameter cases for valid inputs to
# MakeListArrayFromParentIndicesAndValues, each paired with the list array the
# call is expected to produce.
_MAKE_LIST_ARRAY_TEST_CASES = [
    # No values at all: every parent becomes a null list.
    {
        "testcase_name": "parents_are_all_empty",
        "num_parents": 5,
        "parent_indices": np.array([], dtype=np.int64),
        "values": pa.array([], type=pa.int64()),
        "expected": pa.array([None, None, None, None, None],
                             type=pa.list_(pa.int64())),
    },
    # `long` only exists on Python 2; the conditional keeps Python 3 from ever
    # evaluating that name.
    {
        "testcase_name": "long_num_parent",
        "num_parents": (long(1) if six.PY2 else 1),
        "parent_indices": np.array([0], dtype=np.int64),
        "values": pa.array([1]),
        "expected": pa.array([[1]]),
    },
    # Parents before the first referenced index become nulls.
    {
        "testcase_name": "leading nones",
        "num_parents": 3,
        "parent_indices": np.array([2], dtype=np.int64),
        "values": pa.array([1]),
        "expected": pa.array([None, None, [1]]),
    },
    # Several values share a parent, and unreferenced parents become nulls.
    {
        "testcase_name": "same_parent_and_holes",
        "num_parents": 4,
        "parent_indices": np.array([0, 0, 0, 3, 3], dtype=np.int64),
        "values": pa.array(["a", "b", "c", "d", "e"]),
        "expected": pa.array([["a", "b", "c"], None, None, ["d", "e"]]),
    },
]


class MakeListArrayFromParentIndicesAndValuesTest(parameterized.TestCase):
  """Parameterized tests for `MakeListArrayFromParentIndicesAndValues`."""

  @parameterized.named_parameters(*_MAKE_LIST_ARRAY_INVALID_INPUT_TEST_CASES)
  def testInvalidInput(self, num_parents, parent_indices, values,
                       expected_error, expected_error_regexp):
    """Checks that invalid arguments raise the expected error and message."""
    with self.assertRaisesRegex(expected_error, expected_error_regexp):
      make_list_array = arrow_util.MakeListArrayFromParentIndicesAndValues
      make_list_array(num_parents, parent_indices, values)

  @parameterized.named_parameters(*_MAKE_LIST_ARRAY_TEST_CASES)
  def testMakeListArray(self, num_parents, parent_indices, values, expected):
    """Checks that valid arguments produce the expected list array."""
    built = arrow_util.MakeListArrayFromParentIndicesAndValues(
        num_parents, parent_indices, values)
    arrays_match = built.equals(expected)
    failure_message = "actual: {}, expected: {}".format(built, expected)
    self.assertTrue(arrays_match, failure_message)


class EnumerateArraysTest(absltest.TestCase):

def testInvalidWeightColumnMissingValue(self):
Expand Down
30 changes: 0 additions & 30 deletions tensorflow_data_validation/arrow/cc/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,6 @@ cc_library(
],
)

# C++ library target "arrow_util", built from arrow_util.cc with the public
# header arrow_util.h. It depends on the sibling :common and :init_numpy
# targets, plus Arrow, Abseil (strings, span) and the Python headers.
cc_library(
name = "arrow_util",
srcs = ["arrow_util.cc"],
hdrs = ["arrow_util.h"],
deps = [
":common",
":init_numpy",
"@arrow",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
"@local_config_python//:python_headers",
],
)

cc_library(
name = "decoded_examples_to_arrow",
srcs = ["decoded_examples_to_arrow.cc"],
Expand All @@ -47,22 +33,6 @@ cc_library(
],
)

# C++ library target "merge", built from merge.cc with the public header
# merge.h. It depends on the sibling :common and :init_numpy targets, plus
# Arrow, Abseil (flat_hash_map, strings, variant) and the NumPy and Python
# headers.
cc_library(
name = "merge",
srcs = ["merge.cc"],
hdrs = ["merge.h"],
deps = [
":common",
":init_numpy",
"@arrow",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:variant",
"@local_config_python//:numpy_headers",
"@local_config_python//:python_headers",
],
)

cc_library(
name = "init_numpy",
srcs = ["init_numpy.cc"],
Expand Down
Loading

0 comments on commit bf40237

Please sign in to comment.