Skip to content

Commit

Permalink
Update census example to use RaggedFeature (and resolve related TODO).
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 489961032
  • Loading branch information
zoyahav authored and tfx-copybara committed Nov 21, 2022
1 parent ab665ad commit 8006f96
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 29 deletions.
30 changes: 14 additions & 16 deletions examples/census_example_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,14 +53,16 @@
]


RAW_DATA_FEATURE_SPEC = dict([(name, tf.io.FixedLenFeature([], tf.string))
for name in CATEGORICAL_FEATURE_KEYS] +
[(name, tf.io.FixedLenFeature([], tf.float32))
for name in NUMERIC_FEATURE_KEYS] +
[(name, tf.io.VarLenFeature(tf.float32))
for name in OPTIONAL_NUMERIC_FEATURE_KEYS] +
[(LABEL_KEY,
tf.io.FixedLenFeature([], tf.string))])
RAW_DATA_FEATURE_SPEC = dict(
[(name, tf.io.FixedLenFeature([], tf.string))
for name in CATEGORICAL_FEATURE_KEYS] +
[(name, tf.io.FixedLenFeature([], tf.float32))
for name in NUMERIC_FEATURE_KEYS] +
[(name, # pylint: disable=g-complex-comprehension
tf.io.RaggedFeature(
tf.float32, value_key=name, partitions=[], row_splits_dtype=tf.int64))
for name in OPTIONAL_NUMERIC_FEATURE_KEYS] +
[(LABEL_KEY, tf.io.FixedLenFeature([], tf.string))])

_SCHEMA = tft.DatasetMetadata.from_feature_spec(RAW_DATA_FEATURE_SPEC).schema

Expand Down Expand Up @@ -121,14 +123,10 @@ def preprocessing_fn(inputs):
outputs[key] = tft.scale_to_0_1(inputs[key])

for key in OPTIONAL_NUMERIC_FEATURE_KEYS:
# This is a SparseTensor because it is optional. Here we fill in a default
# value when it is missing.
sparse = tf.sparse.SparseTensor(inputs[key].indices, inputs[key].values,
[inputs[key].dense_shape[0], 1])
dense = tf.sparse.to_dense(sp_input=sparse, default_value=0.)
# Reshaping from a batch of vectors of size 1 to a batch to scalars.
dense = tf.squeeze(dense, axis=1)
outputs[key] = tft.scale_to_0_1(dense)
# This is a RaggedTensor because it is optional. Here we fill in a default
# value when it is missing, after scaling it.
outputs[key] = tft.scale_to_0_1(inputs[key]).to_tensor(
default_value=0., shape=[None, 1])

# For all categorical columns except the label column, we generate a
# vocabulary, and convert the string feature to a one-hot encoding.
Expand Down
10 changes: 5 additions & 5 deletions examples/census_example_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@

import os

import tensorflow as tf
import census_example
import census_example_common
from tensorflow_transform import test_case
import local_model_server


class CensusExampleTest(tf.test.TestCase):
class CensusExampleTest(test_case.TransformTestCase):

def testCensusExampleAccuracy(self):
raw_data_dir = os.path.join(os.path.dirname(__file__), 'testdata/census')
Expand Down Expand Up @@ -106,13 +106,13 @@ def testCensusExampleAccuracy(self):
}"""
results = local_model_server.make_classification_request(
address, ascii_classification_request)
self.assertEqual(len(results), 1)
self.assertEqual(len(results[0].classes), 2)
self.assertLen(results, 1)
self.assertLen(results[0].classes, 2)
self.assertEqual(results[0].classes[0].label, '0')
self.assertLess(results[0].classes[0].score, 0.01)
self.assertEqual(results[0].classes[1].label, '1')
self.assertGreater(results[0].classes[1].score, 0.99)


if __name__ == '__main__':
tf.test.main()
test_case.main()
12 changes: 4 additions & 8 deletions examples/census_example_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,11 +87,10 @@ def transform_dataset(data):
for key, val in data.items():
if key not in common.RAW_DATA_FEATURE_SPEC:
continue
if isinstance(common.RAW_DATA_FEATURE_SPEC[key], tf.io.VarLenFeature):
# TODO(b/169666856): Remove conversion to sparse once ragged tensors are
# natively supported.
if isinstance(common.RAW_DATA_FEATURE_SPEC[key], tf.io.RaggedFeature):
# make_csv_dataset will set the value to 0 when it's missing.
raw_features[key] = tf.RaggedTensor.from_tensor(
tf.expand_dims(val, -1)).to_sparse()
tf.expand_dims(val, axis=-1), padding=0)
continue
raw_features[key] = val
transformed_features = tft_layer(raw_features)
Expand Down Expand Up @@ -189,10 +188,7 @@ def train_and_evaluate(raw_train_eval_data_path_pattern,

inputs = {}
for key, spec in feature_spec.items():
if isinstance(spec, tf.io.VarLenFeature):
inputs[key] = tf.keras.layers.Input(
shape=[None], name=key, dtype=spec.dtype, sparse=True)
elif isinstance(spec, tf.io.FixedLenFeature):
if isinstance(spec, tf.io.FixedLenFeature):
# TODO(b/208879020): Move into schema such that spec.shape is [1] and not
# [] for scalars.
inputs[key] = tf.keras.layers.Input(
Expand Down

1 comment on commit 8006f96

@pritamdodeja
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I submitted PR #270 related to this, and it still appears to be pending review. Could you please give me feedback on that PR? Also, which feature in this dataset is represented by a RaggedTensor? Thank you!

Please sign in to comment.