Skip to content

Commit

Permalink
Do not fuse instructions inside custom fusions/calls
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 721595219
  • Loading branch information
Google-ML-Automation committed Jan 31, 2025
1 parent 9a03576 commit 4bc8f20
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 0 deletions.
33 changes: 33 additions & 0 deletions xla/service/cpu/cpu_instruction_fusion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,43 @@ bool CanBeOutputFusedIntoSomeOperand(const HloInstruction* consumer) {
(CanBeOutputFused(consumer->operand(0), consumer) ||
CanBeOutputFused(consumer->operand(1), consumer));
}

} // namespace

void CpuInstructionFusion::ComputeInstructionsToSkip(
HloModule* module,
const absl::flat_hash_set<absl::string_view>& execution_threads) {
const auto computations_list =
module->MakeComputationPostOrder(execution_threads);
instructions_to_skip_.clear();
for (auto* computation : computations_list) {
for (auto* instruction : computation->MakeInstructionPostOrder()) {
if (instruction->IsCustomFusion() ||
instruction->opcode() == HloOpcode::kCustomCall) {
HloCallableInstruction* callable =
Cast<HloCallableInstruction>(instruction);
if (callable->called_computations().empty()) {
continue;
}
for (HloInstruction* instr :
callable->called_computation()->instructions())
instructions_to_skip_.insert(instr);
}
}
}
}

bool CpuInstructionFusion::ShouldSkip(const HloInstruction* inst) const {
return instructions_to_skip_.contains(inst);
}

FusionDecision CpuInstructionFusion::ShouldFuse(HloInstruction* consumer,
int64_t operand_index) {
if (ShouldSkip(consumer)) {
return FusionDecision::Forbid(
"Don't fuse instructions from custom fusions/calls");
}

HloInstruction* producer = consumer->mutable_operand(operand_index);
VLOG(2) << "Considering for fusion: operand " << operand_index << " of "
<< consumer->ToString();
Expand Down
8 changes: 8 additions & 0 deletions xla/service/cpu/cpu_instruction_fusion.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ class CpuInstructionFusion : public InstructionFusion {
const absl::flat_hash_set<absl::string_view>&
execution_threads) override {
fusion_node_evaluations_.clear();
ComputeInstructionsToSkip(module, execution_threads);
return InstructionFusion::Run(module, execution_threads);
}

Expand All @@ -62,10 +63,17 @@ class CpuInstructionFusion : public InstructionFusion {
// Returns if a constant is large enough to be considered a large constant.
bool IsLargeConstant(const HloInstruction* constant) const;

bool ShouldSkip(const HloInstruction* inst) const;
void ComputeInstructionsToSkip(
HloModule* module,
const absl::flat_hash_set<absl::string_view>& execution_threads);

// Keep track of the number of times each instruction inside a fusion node is
// indexed with different index vectors.
absl::flat_hash_map<const HloInstruction*, FusionNodeIndexingEvaluation>
fusion_node_evaluations_;

absl::flat_hash_set<const HloInstruction*> instructions_to_skip_;
};

} // namespace cpu
Expand Down
52 changes: 52 additions & 0 deletions xla/service/cpu/cpu_instruction_fusion_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -976,5 +976,57 @@ ENTRY main {
HloOpcode::kConstant, HloOpcode::kAdd, HloOpcode::kAdd});
}

TEST_F(InstructionFusionTest, SkipCustomFusions) {
absl::string_view module_string = R"(
HloModule module
%fused_computation (param_0: f32[10,10], param_1: f32[10,10]) -> f32[10,10] {
%param_0 = f32[10,10]{1,0} parameter(0)
%param_1 = f32[10,10]{1,0} parameter(1)
%add = f32[10,10]{1,0} add(f32[10,10]{1,0} %param_0, f32[10,10]{1,0} %param_1)
%subtract = f32[10,10]{1,0} subtract(f32[10,10]{1,0} %param_0, f32[10,10]{1,0} %param_1)
ROOT %multiply = f32[10,10]{1,0} multiply(f32[10,10]{1,0} %add, f32[10,10]{1,0} %subtract)
}
ENTRY %main (Arg_0: f32[10,10], Arg_1: f32[10,10]) -> f32[10,10] {
%Arg_0 = f32[10,10]{1,0} parameter(0), metadata={op_name="x"}
%Arg_1 = f32[10,10]{1,0} parameter(1), metadata={op_name="y"}
ROOT %subtract_multiply_fusion = f32[10,10]{1,0} fusion(f32[10,10]{1,0} %Arg_0, f32[10,10]{1,0} %Arg_1), kind=kCustom, calls=%fused_computation
}
)";

TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(module_string));
TF_ASSERT_OK_AND_ASSIGN(bool changed,
CpuInstructionFusion().Run(module.get()));
EXPECT_FALSE(changed);
}

TEST_F(InstructionFusionTest, SkipComputationsAttachedToCustomCalls) {
absl::string_view module_string = R"(
HloModule module
%custom_computation (param_0: f32[10,10], param_1: f32[10,10]) -> f32[10,10] {
%param_0 = f32[10,10]{1,0} parameter(0)
%param_1 = f32[10,10]{1,0} parameter(1)
%add = f32[10,10]{1,0} add(f32[10,10]{1,0} %param_0, f32[10,10]{1,0} %param_1)
%subtract = f32[10,10]{1,0} subtract(f32[10,10]{1,0} %param_0, f32[10,10]{1,0} %param_1)
ROOT %multiply = f32[10,10]{1,0} multiply(f32[10,10]{1,0} %add, f32[10,10]{1,0} %subtract)
}
ENTRY %main (Arg_0: f32[10,10], Arg_1: f32[10,10]) -> f32[10,10] {
%Arg_0 = f32[10,10]{1,0} parameter(0), metadata={op_name="x"}
%Arg_1 = f32[10,10]{1,0} parameter(1), metadata={op_name="y"}
ROOT %custom_call = f32[10,10]{1,0} custom-call(f32[10,10]{1,0} %Arg_0, f32[10,10]{1,0} %Arg_1), custom_call_target="target", called_computations={%custom_computation}
}
)";

TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(module_string));
TF_ASSERT_OK_AND_ASSIGN(bool changed,
CpuInstructionFusion().Run(module.get()));
EXPECT_FALSE(changed);
}

} // namespace
} // namespace xla::cpu

0 comments on commit 4bc8f20

Please sign in to comment.