Skip to content

Commit

Permalink
[XLA] HLO Pass Unit Testing: Add RunAndFilecheckHloRewrite API vers…
Browse files Browse the repository at this point in the history
…ion that takes HLO module with interleaved // CEHCK lines as an input.

PiperOrigin-RevId: 718665282
  • Loading branch information
abhigunj authored and Google-ML-Automation committed Jan 23, 2025
1 parent 0cf6cd7 commit dbe7ade
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 30 deletions.
3 changes: 0 additions & 3 deletions xla/hlo/testlib/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,6 @@ cc_library(
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:string_view",
"@com_google_absl//absl/types:span",
"@tsl//tsl/platform:errors",
"@tsl//tsl/platform:logging",
"@tsl//tsl/platform:statusor",
"@tsl//tsl/platform:test",
],
)
Expand Down
12 changes: 8 additions & 4 deletions xla/hlo/testlib/hlo_hardware_independent_test_base.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,6 @@ limitations under the License.
#include "xla/shape.h"
#include "xla/shape_util.h"
#include "xla/util.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/logging.h"
#include "tsl/platform/statusor.h"
#include "tsl/platform/test.h"

namespace xla {

Expand Down Expand Up @@ -216,6 +212,14 @@ void HloHardwareIndependentTestBase::RunAndFilecheckHloRewrite(
}
}

void HloHardwareIndependentTestBase::RunAndFilecheckHloRewrite(
absl::string_view hlo_with_checks, HloPassInterface&& hlo_pass,
std::function<void(HloModule*)> after_pass_checks,
const HloModuleConfig* config) const {
RunAndFilecheckHloRewrite(hlo_with_checks, std::move(hlo_pass),
hlo_with_checks, after_pass_checks, config);
}

void HloHardwareIndependentTestBase::RunAndFilecheckHloModuleGroupRewrite(
absl::Span<const absl::string_view> hlo_module_strs,
HloPassInterface&& hlo_pass,
Expand Down
12 changes: 11 additions & 1 deletion xla/hlo/testlib/hlo_hardware_independent_test_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -134,13 +134,23 @@ class HloHardwareIndependentTestBase : public ::testing::Test {
// generally used for CPU:AOT compilation.
static void SetAotFastMathDebugOptions(DebugOptions* options);

// Runs pass `hlo_pass` on input HLO module `hlo` with optional config, and
// FileChecks the result against interleaved expected `CHECK` directives.
//
// If the rewrite has changed the module, also runs `additional_checks` on the
// result.
void RunAndFilecheckHloRewrite(
absl::string_view hlo_with_checks, HloPassInterface&& hlo_pass,
std::function<void(HloModule*)> after_pass_checks = nullptr,
const HloModuleConfig* config = nullptr) const;

// Runs pass `hlo_pass` on input HLO module `hlo` with optional config, and
// FileChecks the result against `expected`.
//
// If the rewrite has changed the module, also runs `additional_checks` on the
// result.
void RunAndFilecheckHloRewrite(
absl::string_view hlo, HloPassInterface&& hlo_pass,
absl::string_view hlo_with_filecheck_lines, HloPassInterface&& hlo_pass,
std::optional<absl::string_view> expected,
std::function<void(HloModule*)> after_pass_checks = nullptr,
const HloModuleConfig* config = nullptr) const;
Expand Down
12 changes: 0 additions & 12 deletions xla/hlo/transforms/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -511,20 +511,8 @@ xla_cc_test(
srcs = ["add_original_value_test.cc"],
deps = [
":add_original_value",
"//xla:shape_util",
"//xla:window_util",
"//xla:xla_data_proto_cc",
"//xla/hlo/ir:hlo",
"//xla/hlo/testlib:hlo_hardware_independent_test_base",
"//xla/hlo/testlib:pattern_matcher_gmock",
"//xla/hlo/testlib:verified_hlo_module",
"//xla/service:pattern_matcher",
"//xla/tsl/lib/core:status_test_util",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:string_view",
"@com_google_googletest//:gtest_main",
"@tsl//tsl/platform:status_matchers",
"@tsl//tsl/platform:statusor",
"@tsl//tsl/platform:test",
],
)
16 changes: 6 additions & 10 deletions xla/hlo/transforms/add_original_value_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@ limitations under the License.
#include <gtest/gtest.h>
#include "absl/strings/string_view.h"
#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h"
#include "tsl/platform/statusor.h"
#include "tsl/platform/test.h"

namespace xla {
namespace {
Expand Down Expand Up @@ -79,22 +77,20 @@ CHECK: ((f32[], f32[3]{0}), f32[2,3]{1,0}) tuple(%[[TUPLE]], %[[V3]]), origin={
TEST_F(AddOriginalValueTest, GetTupleElement) {
constexpr absl::string_view hlo_string = R"(
HloModule test, entry_computation_layout={()->s32[2,3]{1,0}}
// CHECK-LABEL: test
ENTRY test {
// CHECK: %[[CONSTANT1:.*]] = f32[3]{0} constant({1, 2, 3}), origin={{[{]}}{"[[CONSTANT1]]"}
constant = f32[3]{0} constant({1, 2, 3})
// CHECK-NEXT: %[[CONSTANT2:.*]] = s32[2,3]{1,0} constant({ { 1, 2, 3 }, { 4, 5, 6 } }), origin={{[{]}}{"[[CONSTANT2]]"}
constant.1 = s32[2,3]{1,0} constant({ { 1, 2, 3 }, { 4, 5, 6 } })
// CHECK-NEXT: %[[TUPLE:.*]] = (f32[3]{0}, s32[2,3]{1,0}) tuple(%[[CONSTANT1]], %[[CONSTANT2]]), origin={({"[[CONSTANT1]]"}, {"[[CONSTANT2]]"})}
tuple = (f32[3]{0}, s32[2,3]{1,0}) tuple(f32[3]{0} constant, s32[2,3]{1,0} constant.1)
// CHECK-NEXT: s32[2,3]{1,0} get-tuple-element(%[[TUPLE]]), index=1, origin={{[{]}}{"[[CONSTANT2]]"}
ROOT get-tuple-element = s32[2,3]{1,0} get-tuple-element((f32[3]{0}, s32[2,3]{1,0}) tuple), index=1
}
)";

RunAndFilecheckHloRewrite(hlo_string, AddOriginalValue(), R"(
CHECK: %[[CONSTANT1:.*]] = f32[3]{0} constant({1, 2, 3}), origin={{[{]}}{"[[CONSTANT1]]"}
CHECK: %[[CONSTANT2:.*]] = s32[2,3]{1,0} constant({ { 1, 2, 3 }, { 4, 5, 6 } }), origin={{[{]}}{"[[CONSTANT2]]"}
CHECK: %[[TUPLE:.*]] = (f32[3]{0}, s32[2,3]{1,0}) tuple(%[[CONSTANT1]], %[[CONSTANT2]]), origin={({"[[CONSTANT1]]"}, {"[[CONSTANT2]]"})}
CHECK: s32[2,3]{1,0} get-tuple-element(%[[TUPLE]]), index=1, origin={{[{]}}{"[[CONSTANT2]]"}
)");
RunAndFilecheckHloRewrite(hlo_string, AddOriginalValue());
}

TEST_F(AddOriginalValueTest, GetTupleElementNonSymbolic) {
Expand Down

0 comments on commit dbe7ade

Please sign in to comment.