[XLA:SchedulingAnnotations] Support having multiple computations containing the same scheduling annotation. Treat the same-id annotation groups from different computations independently.

PiperOrigin-RevId: 716446352
seherellis authored and Google-ML-Automation committed Jan 17, 2025
1 parent 4c82a98 commit ed65641
Showing 6 changed files with 196 additions and 116 deletions.
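The heart of the change is a re-keying of AnnotationTracker's bookkeeping: annotation groups are now tracked per (annotation id, computation) pair instead of per annotation id alone, so the same scheduling-group id can appear in, say, an entry computation and a while body without the two groups being merged. The following is only an illustrative sketch of the before/after keying, using stand-in Computation/Instruction types rather than the real HLO classes:

#include <cstdint>
#include <vector>

#include "absl/container/flat_hash_map.h"

struct Computation;   // stand-in for xla::HloComputation
struct Instruction;   // stand-in for xla::HloInstruction

// Old keying: one module-wide instruction list per annotation id.
using FlatGroups =
    absl::flat_hash_map<int64_t, std::vector<const Instruction*>>;

// New keying: annotation id -> computation -> instructions, so the same id
// used in two computations forms two independent groups.
using PerComputationGroups = absl::flat_hash_map<
    int64_t, absl::flat_hash_map<const Computation*,
                                 std::vector<const Instruction*>>>;

// Lookups become two-level, mirroring the new
// GetInstructions(computation, annotation) accessors in the diff below.
const std::vector<const Instruction*>& GroupIn(
    const PerComputationGroups& groups, const Computation* comp,
    int64_t annotation) {
  return groups.at(annotation).at(comp);
}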
2 changes: 1 addition & 1 deletion xla/service/BUILD
@@ -1186,6 +1186,7 @@ cc_library(
"//xla/hlo/ir:hlo",
"//xla/hlo/ir:ptrvec",
"//xla/hlo/pass:hlo_pass",
"//xla/tsl/platform:errors",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
@@ -1197,7 +1198,6 @@ cc_library(
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/types:span",
"@tsl//tsl/platform:statusor",
],
)

30 changes: 19 additions & 11 deletions xla/service/latency_hiding_scheduler.cc
@@ -56,9 +56,9 @@ limitations under the License.
#include "xla/shape.h"
#include "xla/shape_util.h"
#include "xla/status_macros.h"
#include "xla/tsl/platform/errors.h"
#include "xla/util.h"
#include "xla/xla.pb.h"
#include "tsl/platform/statusor.h"

namespace xla {
namespace {
@@ -1618,17 +1618,18 @@ absl::StatusOr<HloGraphNode*> FindAndExtractBestAnnotatedNode(
}

absl::Status DefaultSchedulerCore::ScheduleAnnotation(
int64_t annotation,
const HloComputation* computation, int64_t annotation,
DefaultSchedulerCore::SchedulingState* sched_state) const {
// Create the ready set with the roots of the annotation
TF_RET_CHECK(sched_state->annotation_ready.empty());
for (const HloInstruction* instr :
annotation_tracker_->GetRootInstructions(annotation)) {
annotation_tracker_->GetRootInstructions(computation, annotation)) {
sched_state->annotation_ready.push_back(
&sched_state->sched_graph.GetNode(instr));
}
int64_t num_scheduled = 0;
int64_t annotation_size = annotation_tracker_->GetNumInstructions(annotation);
int64_t annotation_size =
annotation_tracker_->GetNumInstructions(computation, annotation);
while (!sched_state->annotation_ready.empty()) {
// Print the current annotation ready queue.
VLOG(2) << "Current annotation ready queue:";
@@ -1863,11 +1864,13 @@ absl::StatusOr<HloGraphNode::TimeCost> DefaultSchedulerCore::ScheduleNode(
<< " ready_num_nodes_with_annotation: "
<< sched_state->ready_num_nodes_with_annotation[annotation]
<< " num_root_instructions: "
<< annotation_tracker_->GetNumRootInstructions(annotation);
<< annotation_tracker_->GetNumRootInstructions(
n->GetInstr().parent(), annotation);
// LegalizeSchedulingAnnotations pass should have made sure that we will
// eventually reach a state where all of the annotation root instructions
// will be ready at the same time.
if (annotation_tracker_->GetNumRootInstructions(annotation) ==
if (annotation_tracker_->GetNumRootInstructions(n->GetInstr().parent(),
annotation) ==
sched_state->ready_num_nodes_with_annotation[annotation]) {
sched_state->ready_annotations.push_back(annotation);
}
@@ -2245,7 +2248,7 @@ void HloScheduleGraph::AnnotateGraph(
const HloComputation* comp = original_order_[0]->parent();
for (int64_t annotation : annotation_tracker->GetAnnotations(comp)) {
for (const HloInstruction* instr :
annotation_tracker->GetInstructions(annotation)) {
annotation_tracker->GetInstructions(comp, annotation)) {
HloGraphNode& node = GetNode(instr);
TF_CHECK_OK(node.SetAnnotation(annotation));
}
@@ -2306,8 +2309,10 @@ absl::Status DefaultSchedulerCore::SchedulingStep(

bool DefaultSchedulerCore::SchedulingAnnotationCrossesOverlapLimit(
const SchedulingState& sched_state, int64_t annotation) {
const HloComputation* comp =
sched_state.sched_graph.GetOriginalInstrList()[0]->parent();
for (const HloInstruction* instr :
annotation_tracker_->GetInstructions(annotation)) {
annotation_tracker_->GetInstructions(comp, annotation)) {
if (scheduling_instruction_crosses_overlap_limit_(
sched_state, &sched_state.sched_graph.GetNode(instr))) {
return true;
@@ -2359,8 +2364,10 @@ DefaultSchedulerCore::ScheduleComputation(const HloComputation* computation) {
<< " ready_num_nodes_with_annotation: "
<< sched_state.ready_num_nodes_with_annotation[annotation]
<< " num_root_instructions: "
<< annotation_tracker_->GetNumRootInstructions(annotation);
if (annotation_tracker_->GetNumRootInstructions(annotation) ==
<< annotation_tracker_->GetNumRootInstructions(computation,
annotation);
if (annotation_tracker_->GetNumRootInstructions(computation,
annotation) ==
sched_state.ready_num_nodes_with_annotation[annotation]) {
sched_state.ready_annotations.push_back(annotation);
}
@@ -2400,7 +2407,8 @@ DefaultSchedulerCore::ScheduleComputation(const HloComputation* computation) {
sched_state.ready_annotations.pop_back();
VLOG(2) << "------- BEGIN ANNOTATION: " << annotation << " -------";
sched_state.ongoing_annotation = annotation;
TF_RETURN_IF_ERROR(ScheduleAnnotation(annotation, &sched_state));
TF_RETURN_IF_ERROR(
ScheduleAnnotation(computation, annotation, &sched_state));
VLOG(2) << "------- END ANNOTATION: " << annotation << " --------";
sched_state.ongoing_annotation = -1;
continue;
63 changes: 36 additions & 27 deletions xla/service/latency_hiding_scheduler.h
@@ -343,12 +343,10 @@ class AnnotationTracker {
absl::flat_hash_set<int64_t> annotations;
for (const HloInstruction* instr : comp->instructions()) {
if (auto annotation = GetAnnotation(instr)) {
// LegalizeSchedulingAnnotations pass should have made sure that the
// same annotation id is not used in multiple computations.
if (annotations.insert(annotation.value()).second) {
comp_annotation_map_[comp].push_back(annotation.value());
}
annotations_[annotation.value()].push_back(instr);
annotations_[annotation.value()][comp].push_back(instr);
}
}
}
@@ -367,16 +365,19 @@
return std::nullopt;
}
std::vector<const HloInstruction*> GetInstructions(
const int64_t annotation) const {
return annotations_.at(annotation);
const HloComputation* comp, const int64_t annotation) const {
return annotations_.at(annotation).at(comp);
}
int64_t GetNumInstructions(const int64_t annotation) {
return annotations_[annotation].size();
int64_t GetNumInstructions(const HloComputation* comp,
const int64_t annotation) {
return annotations_[annotation][comp].size();
}
void FindAnnotationRoots(const int64_t annotation) {
void FindAnnotationRoots(const HloComputation* comp,
const int64_t annotation) {
absl::flat_hash_set<const HloInstruction*> seen_instructions(
annotations_[annotation].begin(), annotations_[annotation].end());
for (const HloInstruction* instr : annotations_.at(annotation)) {
annotations_[annotation][comp].begin(),
annotations_[annotation][comp].end());
for (const HloInstruction* instr : annotations_.at(annotation).at(comp)) {
bool has_annotated_user = false;
for (const PtrVec<HloInstruction*>& users :
{instr->users(), instr->control_successors()}) {
@@ … @@
}
if (!has_annotated_user) {
VLOG(3) << "Annotation: " << annotation << ", root: " << instr->name();
annotation_roots_[annotation].push_back(instr);
annotation_roots_[annotation][comp].push_back(instr);
}
}
}
int64_t GetNumRootInstructions(const int64_t annotation) {
if (!annotation_roots_.contains(annotation)) {
FindAnnotationRoots(annotation);
int64_t GetNumRootInstructions(const HloComputation* comp,
const int64_t annotation) {
if (!annotation_roots_[annotation].contains(comp)) {
FindAnnotationRoots(comp, annotation);
}
return annotation_roots_[annotation].size();
return annotation_roots_[annotation][comp].size();
}
std::vector<const HloInstruction*> GetRootInstructions(
const int64_t annotation) {
const HloComputation* comp, const int64_t annotation) {
if (!annotation_roots_.contains(annotation)) {
FindAnnotationRoots(annotation);
FindAnnotationRoots(comp, annotation);
}
return annotation_roots_[annotation];
return annotation_roots_[annotation][comp];
}
void PrintAnnotationSets(int64_t level) const {
for (const auto& [annotation, instrs] : annotations_) {
VLOG(level) << "Annotation " << annotation << " has " << instrs.size()
<< " instructions";
for (const HloInstruction* instr : instrs) {
VLOG(level) << " " << instr->name();
for (const auto& [annotation, comp_instr_vector] : annotations_) {
for (const auto& [comp, instrs] : comp_instr_vector) {
VLOG(level) << "Annotation " << annotation << " has " << instrs.size()
<< " instructions in computation " << comp->name();
for (const HloInstruction* instr : instrs) {
VLOG(level) << " " << instr->name();
}
}
}
}
@@ … @@
const HloModule* module_;
absl::flat_hash_map<const HloComputation*, std::vector<int64_t>>
comp_annotation_map_;
absl::flat_hash_map<int64_t, std::vector<const HloInstruction*>> annotations_;
absl::flat_hash_map<int64_t, std::vector<const HloInstruction*>>
absl::flat_hash_map<int64_t,
absl::flat_hash_map<const HloComputation*,
std::vector<const HloInstruction*>>>
annotations_;
absl::flat_hash_map<int64_t,
absl::flat_hash_map<const HloComputation*,
std::vector<const HloInstruction*>>>
annotation_roots_;
};

@@ -1064,9 +1073,9 @@ class DefaultSchedulerCore : public SchedulerCore {
absl::Status AnnotatedSchedulingStep(
HloGraphNode* node,
DefaultSchedulerCore::SchedulingState* sched_state) const;
// Schedules all of the nodes in the given annotation.
// Schedules all of the nodes with the given annotation in computation.
absl::Status ScheduleAnnotation(
int64_t annotation,
const HloComputation* computation, int64_t annotation,
DefaultSchedulerCore::SchedulingState* sched_state) const;
// Update node that has been scheduled.
virtual absl::StatusOr<HloGraphNode::TimeCost> ScheduleNode(
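With the computation threaded through every accessor, callers such as DefaultSchedulerCore now query the tracker per computation. Below is a simplified usage sketch paraphrased from the accessors shown above; it is not code from the change, and DumpGroupsForComputation is a hypothetical helper:

// Hypothetical helper, for illustration only: walks one computation's
// annotation groups via the per-computation accessors declared above.
void DumpGroupsForComputation(AnnotationTracker& tracker,
                              const HloComputation* comp) {
  for (int64_t id : tracker.GetAnnotations(comp)) {
    // Both queries are now keyed by (computation, annotation id).
    VLOG(2) << "annotation " << id << " in " << comp->name() << ": "
            << tracker.GetNumInstructions(comp, id) << " instructions, "
            << tracker.GetNumRootInstructions(comp, id) << " roots";
  }
}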
66 changes: 66 additions & 0 deletions xla/service/latency_hiding_scheduler_test.cc
@@ -3848,4 +3848,70 @@ ENTRY entry {
GetIndex(new_instruction_sequence, "cp2s"));
}

TEST_F(LatencyHidingSchedulerTest, CrossComputationAnnotation) {
constexpr absl::string_view hlo_string = R"(
HloModule module, is_scheduled=true
while_cond {
param = (f32[16,64,256]{2,1,0}, f32[16,64,256]{2,1,0}, pred[]) parameter(0)
ROOT gte = pred[] get-tuple-element(param), index=2
}
while_body {
param = (f32[16,64,256]{2,1,0}, f32[16,64,256]{2,1,0}, pred[]) parameter(0)
gte0 = f32[16,64,256]{2,1,0} get-tuple-element(param), index=0
gte1 = f32[16,64,256]{2,1,0} get-tuple-element(param), index=1
gte2 = pred[] get-tuple-element(param), index=2
cps1 = (f32[16,64,256]{2,1,0}, f32[16,64,256]{2,1,0}, u32[], u32[]) collective-permute-start(gte1), source_target_pairs={{0,1},{1,2},{2,3},{3,0}}, frontend_attributes={_scheduling_group_id="1"}
cpd1 = f32[16,64,256]{2,1,0} collective-permute-done(cps1), frontend_attributes={_scheduling_group_id="1"}
c1 = f32[16,256,256]{2,1,0} convolution(gte0, gte0), window={size=16 stride=15 lhs_dilate=16}, dim_labels=0fb_0io->0fb, frontend_attributes={_scheduling_group_id="1"}
slice = f32[16,64,256]{2,1,0} slice(c1), slice={[0:16], [0:64], [0:256]}
add = f32[16,64,256]{2,1,0} add(gte0, slice)
ROOT tuple = (f32[16,64,256]{2,1,0}, f32[16,64,256]{2,1,0}, pred[]) tuple(add, cpd1, gte2)
}
ENTRY entry {
p0 = f32[256,1024]{1,0} parameter(0)
p1 = f32[16,64,256]{2,1,0} parameter(1)
p2 = f32[16,64,256]{2,1,0} parameter(2)
p3 = pred[] parameter(3)
c0 = f32[16,256,256]{2,1,0} convolution(p1, p2), window={size=16 stride=15 lhs_dilate=16}, dim_labels=0fb_0io->0fb, frontend_attributes={_scheduling_group_id="1"}
ags0 = (f32[256,1024]{1,0}, f32[1024,1024]{1,0}) all-gather-start(p0), replica_groups={{0,1,2,3}}, dimensions={0}, frontend_attributes={_scheduling_group_id="1"}
tuple = (f32[16,64,256]{2,1,0}, f32[16,64,256]{2,1,0}, pred[]) tuple(p1, p2, p3)
while = (f32[16,64,256]{2,1,0}, f32[16,64,256]{2,1,0}, pred[]) while(tuple), condition=while_cond, body=while_body
agd0 = f32[1024,1024]{1,0} all-gather-done(ags0), frontend_attributes={_scheduling_group_id="1"}
gte = f32[16,64,256]{2,1,0} get-tuple-element(while), index=0
ROOT tuple1 = (f32[16,64,256]{2,1,0}, f32[16,256,256]{2,1,0}, f32[1024,1024]{1,0}) tuple(gte, c0, agd0)
}
)";
TF_ASSERT_OK_AND_ASSIGN(auto hlo_module, ParseHloText(hlo_string));
HloSchedule& module_schedule = hlo_module->schedule();
EXPECT_TRUE(hlo_module->has_entry_computation());
auto sched_config = GetDefaultSchedConfig();
EXPECT_TRUE(RunScheduler(hlo_module.get(), sched_config,
std::make_unique<TestLatencyEstimator>())
.ok());
EXPECT_TRUE(hlo_module->has_entry_computation());

std::vector<HloInstruction*> new_instruction_sequence =
module_schedule.sequence(hlo_module->entry_computation()).instructions();
if (VLOG_IS_ON(1)) {
for (auto* new_i : new_instruction_sequence) {
VLOG(1) << new_i->ToString();
}
}

EXPECT_LT(GetIndex(new_instruction_sequence, "ags0"),
GetIndex(new_instruction_sequence, "c0"));
EXPECT_LT(GetIndex(new_instruction_sequence, "c0"),
GetIndex(new_instruction_sequence, "agd0"));
const HloInstruction* while_inst = FindInstruction(hlo_module.get(), "while");
std::vector<HloInstruction*> loop_instruction_sequence =
module_schedule.sequence(while_inst->while_body()).instructions();
EXPECT_LT(GetIndex(loop_instruction_sequence, "cps1"),
GetIndex(loop_instruction_sequence, "c1"));
EXPECT_LT(GetIndex(loop_instruction_sequence, "c1"),
GetIndex(loop_instruction_sequence, "cpd1"));
}

} // namespace xla