Skip to content

Commit

Permalink
Simplify flag parsing in flag_types.cc using fixed_option_set_flag.
Browse files Browse the repository at this point in the history
Also extend fixed_option_set_flag to support aliases and case-insensitive flags if desired.

PiperOrigin-RevId: 724544586
  • Loading branch information
Google-ML-Automation committed Feb 8, 2025
1 parent 6af6baa commit 6818964
Show file tree
Hide file tree
Showing 2 changed files with 137 additions and 9 deletions.
40 changes: 31 additions & 9 deletions xla/tsl/util/fixed_option_set_flag.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,23 @@ limitations under the License.

#include "absl/container/flat_hash_set.h"
#include "absl/log/check.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"

namespace xla {

// Controls FixedOptionSetFlagParser's behavior.
struct FixedOptionSetFlagParserConfig {
// If true, allows aliases for flag options. The first option listed for a
// given name takes precedence when unparsing.
bool allow_aliases = false;
// Whether the flag values are case sensitive. It's a bad practice to have
// case-insensitive flag values. DO NOT SET THIS FIELD TO FALSE IN NEW CODE.
bool case_sensitive_do_not_use_in_new_code = true;
};

// A parser for a flag of type T that takes a fixed set of options. This makes
// it easier and safer to define flags that take a fixed set of options.
// Requires T to support equality comparison, hashing, and conversion to
Expand Down Expand Up @@ -75,16 +86,20 @@ class FixedOptionSetFlagParser {
// Creates a parser for a flag of type T that takes a fixed set of options.
// The options must be valid, i.e., there must be no duplicate names or
// values.
explicit FixedOptionSetFlagParser(const std::vector<FlagOption>& options)
: options_(ValidateFlagOptionsOrDie(options)) {}
explicit FixedOptionSetFlagParser(
const std::vector<FlagOption>& options,
const FixedOptionSetFlagParserConfig& config)
: options_(ValidateFlagOptionsOrDie(options, config)),
case_sensitive_(config.case_sensitive_do_not_use_in_new_code) {}

// Parses the flag from the given text. Returns true if the text is
// valid, and sets the value to the corresponding option. Otherwise, returns
// false and sets the error message.
[[nodiscard]] bool Parse(absl::string_view text, T* value,
std::string* error) const {
for (const auto& option : options_) {
if (text == option.name) {
if ((case_sensitive_ && text == option.name) ||
(!case_sensitive_ && absl::EqualsIgnoreCase(text, option.name))) {
*value = option.value;
return true;
}
Expand Down Expand Up @@ -117,22 +132,27 @@ class FixedOptionSetFlagParser {
// Validates the flag options and returns them. Dies if the options are not
// valid.
static std::vector<FlagOption> ValidateFlagOptionsOrDie(
const std::vector<FlagOption>& options) {
const std::vector<FlagOption>& options,
const FixedOptionSetFlagParserConfig& config) {
// Check that the same name or value is not used multiple times.
absl::flat_hash_set<std::string> names;
absl::flat_hash_set<T> values;
for (const auto& option : options) {
CHECK(!names.contains(option.name))
<< "Duplicate flag option name: " << option.name;
CHECK(!values.contains(option.value))
<< "Duplicate flag option value: " << absl::StrCat(option.value);
names.insert(option.name);
values.insert(option.value);

if (!config.allow_aliases) {
CHECK(!values.contains(option.value))
<< "Duplicate flag option value: " << absl::StrCat(option.value);
values.insert(option.value);
}
}
return options;
}

const std::vector<FlagOption> options_;
const bool case_sensitive_ = true;
};

// Returns the parser for a flag of type T that takes a fixed set of options.
Expand All @@ -146,12 +166,14 @@ class FixedOptionSetFlagParser {
template <typename T>
[[nodiscard]] const FixedOptionSetFlagParser<T>& GetFixedOptionSetFlagParser(
const std::vector<typename FixedOptionSetFlagParser<T>::FlagOption>&
options) {
options,
const FixedOptionSetFlagParserConfig& config = {}) {
// Per Google C++ style guide, we use a function-local static
// variable to ensure that the parser is only created once and never
// destroyed. We cannot use absl::NoDestructor here because it is not
// available in the version of Abseil that openxla uses.
static const auto* const parser = new FixedOptionSetFlagParser<T>(options);
static const auto* const parser =
new FixedOptionSetFlagParser<T>(options, config);
return *parser;
}

Expand Down
106 changes: 106 additions & 0 deletions xla/tsl/util/fixed_option_set_flag_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,5 +71,111 @@ TEST(FixedOptionSetFlag, UnparseFailsForInvalidOptions) {
EXPECT_EQ(AbslUnparseFlag(static_cast<Foo>(123)), "123");
}

enum class FooWithAliases {
kBar,
kBaz,
};

static const FixedOptionSetFlagParser<FooWithAliases>&
GetFooWithAliasesParser() {
static const auto& parser = GetFixedOptionSetFlagParser<FooWithAliases>(
{
{"bar", FooWithAliases::kBar, "the first option"},
// "baz" and "baz2" are aliases for the same option. The first one
// listed takes precedence when unparsing.
{"baz", FooWithAliases::kBaz},
{"baz2", FooWithAliases::kBaz},
},
// Cannot use designated initializers here because tensorflow needs to
// support C++17.
{/*allow_aliases=*/true});
return parser;
}

bool AbslParseFlag(absl::string_view text, FooWithAliases* foo,
std::string* error) {
return GetFooWithAliasesParser().Parse(text, foo, error);
}

std::string AbslUnparseFlag(FooWithAliases foo) {
return GetFooWithAliasesParser().Unparse(foo);
}

TEST(FixedOptionSetFlag, ParseSucceedsForValidOptionsWithAliases) {
FooWithAliases foo;
std::string error;
ASSERT_TRUE(AbslParseFlag("bar", &foo, &error));
EXPECT_EQ(foo, FooWithAliases::kBar);
ASSERT_TRUE(AbslParseFlag("baz", &foo, &error));
EXPECT_EQ(foo, FooWithAliases::kBaz);
ASSERT_TRUE(AbslParseFlag("baz2", &foo, &error));
EXPECT_EQ(foo, FooWithAliases::kBaz);
}

TEST(FixedOptionSetFlag, UnparseSucceedsForValidOptionsWithAliases) {
EXPECT_EQ(AbslUnparseFlag(FooWithAliases::kBar), "bar");
EXPECT_EQ(AbslUnparseFlag(FooWithAliases::kBaz), "baz");
}

TEST(FixedOptionSetFlag, ParseFailsForInvalidOptionsWithAliases) {
FooWithAliases foo;
std::string error;
ASSERT_FALSE(AbslParseFlag("baz3", &foo, &error));
EXPECT_EQ(error,
"Unrecognized flag option: baz3. Valid options are: bar (the first "
"option), baz, baz2.");
}

enum class FooCaseInsensitive {
kBar,
kBaz,
};

static const FixedOptionSetFlagParser<FooCaseInsensitive>&
GetFooCaseInsensitiveParser() {
static const auto& parser = GetFixedOptionSetFlagParser<FooCaseInsensitive>(
{
{"bar", FooCaseInsensitive::kBar, "the first option"},
{"baz", FooCaseInsensitive::kBaz},
},
// Cannot use designated initializers here because tensorflow needs to
// support C++17.
{/*allow_aliases=*/false,
/*case_sensitive_do_not_use_in_new_code=*/false});
return parser;
}

bool AbslParseFlag(absl::string_view text, FooCaseInsensitive* foo,
std::string* error) {
return GetFooCaseInsensitiveParser().Parse(text, foo, error);
}

std::string AbslUnparseFlag(FooCaseInsensitive foo) {
return GetFooCaseInsensitiveParser().Unparse(foo);
}

TEST(FixedOptionSetFlag, ParseSucceedsForValidOptionsCaseInsensitive) {
FooCaseInsensitive foo;
std::string error;
ASSERT_TRUE(AbslParseFlag("BaR", &foo, &error));
EXPECT_EQ(foo, FooCaseInsensitive::kBar);
ASSERT_TRUE(AbslParseFlag("bAz", &foo, &error));
EXPECT_EQ(foo, FooCaseInsensitive::kBaz);
}

TEST(FixedOptionSetFlag, UnparseSucceedsForValidOptionsCaseInsensitive) {
EXPECT_EQ(AbslUnparseFlag(FooCaseInsensitive::kBar), "bar");
EXPECT_EQ(AbslUnparseFlag(FooCaseInsensitive::kBaz), "baz");
}

TEST(FixedOptionSetFlag, ParseFailsForInvalidOptionsCaseInsensitive) {
FooCaseInsensitive foo;
std::string error;
ASSERT_FALSE(AbslParseFlag("foo", &foo, &error));
EXPECT_EQ(error,
"Unrecognized flag option: foo. Valid options are: bar (the first "
"option), baz.");
}

} // namespace
} // namespace xla

0 comments on commit 6818964

Please sign in to comment.