Lower distributed matmul to pipelined algorithm for fine-grained overlap: AG+GEMM layout #3606

Merged: 51 commits, Jan 13, 2025
Changes shown from 35 of 51 commits.

Commits:
38721fe
Host IR: add GetCurrentStream
samnordmann Dec 18, 2024
c4ca266
lint
samnordmann Dec 18, 2024
b517c2b
lower to collective base pipeline AG+GEMM
samnordmann Dec 18, 2024
92ab927
lint
samnordmann Dec 18, 2024
ed4440a
lint
samnordmann Dec 18, 2024
ef8f00c
update with non blocking stream synchronization
samnordmann Dec 18, 2024
36fd2be
make stream synchronization non blocking
samnordmann Dec 18, 2024
1e9f1d0
lint
samnordmann Dec 18, 2024
af06de4
add event to events_ container
samnordmann Dec 18, 2024
5e166a0
destroy event async at create site
samnordmann Dec 18, 2024
e8ffadb
Merge branch 'host_irs/non_blocking_stream_synchronize' into overlap/…
samnordmann Dec 18, 2024
741202b
minor review
samnordmann Dec 23, 2024
353c03c
Merge branch 'main' of github.com:NVIDIA/Fuser into host_irs/get_curr…
samnordmann Dec 23, 2024
4420eb4
Merge branch 'host_irs/get_current_stream' into overlap/lower_matmul_…
samnordmann Dec 23, 2024
d0a9340
Merge branch 'main' of github.com:NVIDIA/Fuser into overlap/lower_mat…
samnordmann Dec 23, 2024
0374604
fix merge
samnordmann Dec 23, 2024
5e07ad8
minor review
samnordmann Dec 23, 2024
b546dce
remove now unnecessary trick of adding artifical outputs
samnordmann Dec 23, 2024
8e8b247
lint
samnordmann Dec 23, 2024
d5b42c2
remove now unnecessary patch on broadcast
samnordmann Dec 23, 2024
ac20f4f
Merge branch 'main' into overlap/lower_matmul_to_hostir
samnordmann Dec 30, 2024
4191ecf
fix typo in canLower
samnordmann Jan 3, 2025
dfc33f2
add Stream parallelType
samnordmann Jan 3, 2025
4e1ecb9
minor reviews
samnordmann Jan 3, 2025
5490085
Merge branch 'main' of github.com:NVIDIA/Fuser into overlap/lower_mat…
samnordmann Jan 3, 2025
a5c70b8
fix bug: allocate dst buffer before posting communication
samnordmann Jan 3, 2025
f66e97a
change order of presegpass
samnordmann Jan 7, 2025
dba90e3
Merge branch 'main' of github.com:NVIDIA/Fuser into overlap/lower_mat…
samnordmann Jan 7, 2025
09ccfa5
minor comment
samnordmann Jan 7, 2025
8287679
fix MultiDeviceReductionTest.UnshardedInput_ShardedOutput/ tests
samnordmann Jan 7, 2025
301e54d
bypass ReorderShardedAxisPass if multiple IO. fix DistributedTransfor…
samnordmann Jan 7, 2025
0c9b65e
change tensor size to loosen tolerance/error in DistributedMatmulTest…
samnordmann Jan 7, 2025
3a0d827
lint
samnordmann Jan 7, 2025
decb055
minor comments
samnordmann Jan 7, 2025
11843d6
typo
samnordmann Jan 7, 2025
caf5d0b
increase tolerance rate
samnordmann Jan 8, 2025
5b6c7bd
still throws if two axis are DIDx, even if one is reduced
samnordmann Jan 8, 2025
632aa1e
support multiple additions/deletions in isInnerResharding
samnordmann Jan 8, 2025
4dde7d7
Merge branch 'main' of github.com:NVIDIA/Fuser into overlap/lower_mat…
samnordmann Jan 9, 2025
8f60b45
use randint and small sizes in DistributedMatmulTest.AnnotateWeightOnly
samnordmann Jan 9, 2025
cdd9e46
add bool option ignore_inner_resharding in canLower
samnordmann Jan 9, 2025
6e9fe35
de-DID-parallelize reduction axis in shardAllLike
samnordmann Jan 9, 2025
428600d
revert patch on isSharded
samnordmann Jan 9, 2025
e70c00f
revert patch on getShardedLogicalAxis and isInnerResharding
samnordmann Jan 9, 2025
0f93a43
revert accepting multiple IO in ReorderShardedAxisPass
samnordmann Jan 9, 2025
6b11d33
revert switching order of passes ReorderShardedAxisPass and InsertRes…
samnordmann Jan 9, 2025
061955f
lint
samnordmann Jan 9, 2025
1a83908
move ignore_inner_resharding as LHS of bool op to lazy evaluate the p…
samnordmann Jan 9, 2025
d506dde
minor comments
samnordmann Jan 13, 2025
a367238
Merge branch 'main' of github.com:NVIDIA/Fuser into overlap/lower_mat…
samnordmann Jan 13, 2025
ea55060
lint
samnordmann Jan 13, 2025
1 change: 1 addition & 0 deletions csrc/host_ir/executor.cpp
@@ -201,6 +201,7 @@ HostIrEvaluator::HostIrEvaluator(
{container_->getDefaultStream(),
c10::cuda::getDefaultCUDAStream(
static_cast<c10::DeviceIndex>(device_index))});
expr_evaluator_.bind("numberOfStreams", params_.number_of_streams);
}

std::vector<at::Tensor> HostIrEvaluator::runWithInput(
3 changes: 3 additions & 0 deletions csrc/host_ir/executor.h
@@ -74,6 +74,9 @@ struct HostIrEvaluatorParams {
// Experimental: whether to cache the fusion executor. WAR: avoids recompilation
// but implicitly assumes that the input shapes don't change over iterations
bool cache_fusion_executor = false;
// number of additional cuda streams to use at runtime for comm+compute
// pipelining
int64_t number_of_streams = 4;
};

class HostIrEvaluator final : public OptOutDispatch {
2 changes: 2 additions & 0 deletions csrc/host_ir/host_ir.h
@@ -208,6 +208,8 @@ class Wait : public Expr {
}
};

// Makes the current stream wait on the given stream. Non-blocking from the host
// point of view.
class Synchronize : public Expr {
public:
using Expr::Expr;
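As background for the comment above: a non-blocking stream synchronization is typically built from a CUDA event recorded on the waited-on stream and waited on by the current stream. The sketch below illustrates only the general technique and is not the nvFuser implementation; the helper name is made up for the example.

#include <cuda_runtime.h>

// Make `current` wait for all work already enqueued on `other`,
// without blocking the host thread.
inline cudaError_t streamWaitOnStream(cudaStream_t current, cudaStream_t other) {
  cudaEvent_t event = nullptr;
  // Timing is disabled: the event is only used for ordering.
  cudaError_t err = cudaEventCreateWithFlags(&event, cudaEventDisableTiming);
  if (err != cudaSuccess) {
    return err;
  }
  // Mark the point in `other` that `current` must wait for.
  err = cudaEventRecord(event, other);
  if (err == cudaSuccess) {
    // Future work on `current` waits for the event; the host returns immediately.
    err = cudaStreamWaitEvent(current, event, 0);
  }
  // Destroying the event right away is safe; the recorded dependency remains.
  cudaEventDestroy(event);
  return err;
}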
164 changes: 143 additions & 21 deletions csrc/host_ir/lower.cpp
@@ -14,6 +14,7 @@
#include <multidevice/device_mesh.h>
#include <multidevice/utils.h>
#include <ops/all_ops.h>
#include <ops/utils.h>
#include <preseg_passes/insert_reshardings.h>
#include <preseg_passes/make_resharding_contiguous.h>
#include <preseg_passes/propagate_shardings.h>
@@ -235,6 +236,10 @@ void lowerToReduceScatter(
std::vector<Expr*> HostIrLower::lower(Expr* c) {
FusionGuard fg(c->fusion());

if (c->isA<MatmulOp>()) {
return lowerToCollectiveBasedPipelinedGemmComm(c);
}

std::vector<Expr*> comms;
NVF_ERROR(
c->inputs().size() == 1 && c->input(0)->isA<TensorView>() &&
@@ -309,9 +314,12 @@ bool HostIrLower::canLower(Expr* expr) {
if (!ir_utils::isTvOp(expr)) {
return false;
}
if (expr->isA<ReductionOp>()) {
auto in = expr->as<ReductionOp>()->in()->as<TensorView>();
auto out = expr->as<ReductionOp>()->out()->as<TensorView>();
if (auto* reduction = dynamic_cast<ReductionOp*>(expr)) {
if (isInnerResharding(expr)) {
return false;
}
auto in = reduction->in()->as<TensorView>();
auto out = reduction->out()->as<TensorView>();
// get the reduced axis
std::vector<IterDomain*> reduction_axis;
std::copy_if(
@@ -328,10 +336,124 @@ bool HostIrLower::canLower(Expr* expr) {
PairwiseLogicalDomainMap(in, out).mapConsumerToProducer();
auto c2p_map_it = c2p_map.find(reduction_axis.at(0));
return c2p_map_it != c2p_map.end() && c2p_map_it->second->isDeviceDim();
} else {
return expr->isA<LoadStoreOp>() &&
(expr->as<LoadStoreOp>()->opType() == LoadStoreOpType::Set);
} else if (auto* ldst = dynamic_cast<LoadStoreOp*>(expr)) {
return !isInnerResharding(ldst) &&
ldst->as<LoadStoreOp>()->opType() == LoadStoreOpType::Set;
} else if (auto* matmul = dynamic_cast<MatmulOp*>(expr)) {
// For now we only support c = matmul(a,b) when b,c are fully replicated and
// a is sharded on axis 1
return !isSharded(matmul->inB()) && !isSharded(matmul->out()) &&
matmul->inA()->axis(0)->getParallelType() == ParallelType::Serial &&
getShardedLogicalAxis(matmul->inA(), ParallelType::DIDx) == 1 &&
matmul->out()->axis(0)->getParallelType() == ParallelType::Stream;
}
return false;
}

std::vector<Expr*> HostIrLower::lowerToCollectiveBasedPipelinedGemmComm(
Expr* expr) {
auto matmul = expr->as<MatmulOp>();
NVF_ERROR(matmul != nullptr, "Expect a MatmulOp, got", expr);
TensorView* tva = matmul->inA();
TensorView* tvb = matmul->inB();
TensorView* tvc = matmul->out();
NVF_ERROR(
!isSharded(tvb), "The B operand ", tvb, " is expected to not be sharded");
NVF_ERROR(
!isSharded(tvc),
"The output ",
matmul->out(),
" is expected to not be sharded");
const int64_t sharded_axis_index =
getShardedLogicalAxis(tva, ParallelType::DIDx);
IterDomain* stream_axis = tva->axis(0);
NVF_ERROR(
stream_axis->getParallelType() == ParallelType::Serial &&
sharded_axis_index == 1,
"The operand A ",
tva,
" is expected to be sharded on the dimension 1");

auto hic = FusionGuard::getCurFusion()->as<hir::HostIrContainer>();

auto* get_current_stream = IrBuilder::create<hir::GetCurrentStream>();
hir::Stream* original_stream = get_current_stream->stream();

TensorView* tva_allgathered =
ops::newValLike(tva, tva->dtype())->as<TensorView>();
tva_allgathered->axis(sharded_axis_index)->parallelize(ParallelType::Serial);
tva_allgathered->setMemoryType(MemoryType::Global);
auto* allocate_tva_allgathered =
IrBuilder::create<kir::Allocate>(tva_allgathered, MemoryType::Global);

tvc->setMemoryType(MemoryType::Global);
auto* allocate_tvc =
IrBuilder::create<kir::Allocate>(tvc, MemoryType::Global);

auto* j =
IrBuilder::create<Val>(DataType::Index); // running index of the for-loop
auto* start = hic->zeroVal();
auto* stop = stream_axis->extent();
auto* step = hic->oneVal();
auto* for_loop = IrBuilder::create<ForLoop>(
stream_axis,
/*index=*/j,
start,
stop,
step,
/*vectorize=*/false,
/*vectorize_shift=*/nullptr,
/*unroll_required=*/false,
CircularBufferLoopStage::NotApplicable,
/*circular_buffer_loop_stage_depth=*/0);

auto* number_of_streams =
IrBuilder::create<NamedScalar>("numberOfStreams", DataType::Int);
auto* stream_index = mod(j, number_of_streams);
auto* stream = IrBuilder::create<hir::Stream>(stream_index);
auto* set_stream = IrBuilder::create<hir::SetCurrentStream>(stream);

TensorView* tva_j = select(tva, 0, j);
TensorView* tva_allgathered_j = select(tva_allgathered, 0, j);
TensorView* tvc_j = select(tvc, 0, j);

NVF_ERROR(
tva->hasDeviceMesh(),
"The matmul's input ",
tva,
"is expected to have a DeviceMesh");
for (auto tv : {tva_j, tva_allgathered_j, tvc_j}) {
tv->setDeviceMesh(tva->getDeviceMesh());
}

auto* communication = IrBuilder::create<Communication>(
CommunicationType::Allgather,
/*out=*/tva_allgathered_j,
/*in=*/tva_j,
/*team=*/tva->getDeviceMesh().vector());
auto* wait = IrBuilder::create<hir::Wait>(communication);

auto* mm = IrBuilder::create<MatmulOp>(tvc_j, tva_allgathered_j, tvb);

auto* set_back_original_stream =
IrBuilder::create<hir::SetCurrentStream>(original_stream);
auto* sync_stream = IrBuilder::create<hir::Synchronize>(stream);

std::vector<Expr*> loop_body = {
set_stream,
tva_j->definition(),
tva_allgathered_j->definition(),
communication,
wait,
tvc_j->definition(),
mm,
set_back_original_stream,
sync_stream};
for (Expr* expr : loop_body) {
for_loop->body().push_back(expr);
}

return {get_current_stream, allocate_tva_allgathered, allocate_tvc, for_loop};
}
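To summarize the lowering above, the host program it builds looks roughly like the following hand-written pseudocode (a sketch for readability only; the real lowering emits the Host IR nodes shown in the diff, and the shapes assume the [S, DIDx(D), M/(S*D), K] layout of operand A):

original_stream = GetCurrentStream()
Allocate tva_allgathered              (global memory)
Allocate tvc                          (global memory)
for (j = 0; j < S; ++j) {             S = extent of the stream axis
  SetCurrentStream(stream[j % numberOfStreams])
  tva_j             = select(tva, 0, j)
  tva_allgathered_j = select(tva_allgathered, 0, j)
  Allgather(out = tva_allgathered_j, in = tva_j, team = mesh)
  Wait(allgather)                     waits on the worker stream, not the host
  tvc_j = select(tvc, 0, j)
  tvc_j = Matmul(tva_allgathered_j, tvb)
  SetCurrentStream(original_stream)
  Synchronize(stream[j % numberOfStreams])   original stream waits; host does not block
}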

std::unique_ptr<hir::HostIrContainer> HostIrLower::lower(
@@ -341,10 +463,10 @@ std::unique_ptr<hir::HostIrContainer> HostIrLower::lower(
// Note: passes run before PreSegmenter optimization passes.
preseg_passes::OptimizationPass<
preseg_passes::PropagateShardingsPass>::runPass(fusion.get());
preseg_passes::OptimizationPass<
preseg_passes::InsertReshardingsPass>::runPass(fusion.get());
preseg_passes::OptimizationPass<
preseg_passes::ReorderShardedAxisPass>::runPass(fusion.get());
preseg_passes::OptimizationPass<
preseg_passes::InsertReshardingsPass>::runPass(fusion.get());
preseg_passes::OptimizationPass<
preseg_passes::MakeReshardingContiguousPass>::runPass(fusion.get());

@@ -397,20 +519,20 @@ std::unique_ptr<hir::HostIrContainer> HostIrLower::lower(
for (auto* expr :
HostIrLower::lower(ir_cloner.clone(group->exprs().at(0)))) {
// Allocate the recv buffers of communications
NVF_ERROR(
expr->isA<Communication>(),
"Expected a Communication but got ",
expr);
auto* communication = expr->as<Communication>();
TensorView* tv = communication->out();
if (tv->getDeviceMesh().has(my_device_index)) {
auto* allocate =
IrBuilder::create<kir::Allocate>(tv, MemoryType::Global);
hic->pushBackTopLevelExprs(allocate);
if (expr->isA<Communication>()) {
auto* communication = expr->as<Communication>();
TensorView* tv = communication->out();
if (tv->getDeviceMesh().has(my_device_index)) {
auto* allocate =
IrBuilder::create<kir::Allocate>(tv, MemoryType::Global);
hic->pushBackTopLevelExprs(allocate);
}
}
hic->pushBackTopLevelExprs(expr);
if (expr->isA<Communication>()) {
auto wait = IrBuilder::create<hir::Wait>(expr->as<Communication>());
hic->pushBackTopLevelExprs(wait);
}
hic->pushBackTopLevelExprs(communication);
auto wait = IrBuilder::create<hir::Wait>(communication);
hic->pushBackTopLevelExprs(wait);
}
} else {
auto host_unit = IrBuilder::create<hir::HostUnit>(
3 changes: 3 additions & 0 deletions csrc/host_ir/lower.h
@@ -24,6 +24,9 @@ class HostIrLower {
static std::unique_ptr<hir::HostIrContainer> lower(
std::unique_ptr<Fusion> fusion,
int64_t my_device_index);

private:
static std::vector<Expr*> lowerToCollectiveBasedPipelinedGemmComm(Expr* expr);
};

} // namespace nvfuser
8 changes: 5 additions & 3 deletions csrc/multidevice/utils.cpp
@@ -100,7 +100,7 @@ std::pair<std::vector<IterDomain*>, std::vector<IterDomain*>> getShardingChanges
bool isSharded(const TensorView* tv) {
bool is_sharded = false;
for (IterDomain* alloc_id : tv->getMaybeAllocationDomain()) {
if (!alloc_id->isDeviceDim()) {
if (!alloc_id->isDeviceDim() || alloc_id->isReduction()) {
A Collaborator commented on the isSharded change above:

This along with several other changes is for supporting rDID? That's intentionally unsupported because rDID (unlike r) means the data only exists in one GPU and the collectives nvFuser practically uses today (e.g. allreduce and reducescatter) don't do that. I'll need to figure out where rDID came from. It's unexpected because your test doesn't try to parallelize a reduction dimension in the first place.

cc @naoyam

@samnordmann (Collaborator, Author) replied on Jan 8, 2025:

> This along with several other changes is for supporting rDID?

This patch is not needed for my test, but the rDID case does indeed appear in other, previously added tests.

> That's intentionally unsupported because

In that case, we should assert that this case is not encountered (but it actually is in the present state); for now, nothing prevents it from occurring:

.iter_type(IterType::Reduction)

For example, take ReduceScatter/PipelineTestTwoStages.Communication/7 with GetParam() = (NCCL, DeviceMesh{1 0}, DeviceMesh{}, true, true, true, 1, false).

Another more subtle case occurs from InsertReshardingsPass. Take for example MultiDeviceReductionTest.UnshardedInput_ShardedOutput/symbolic_sharded_along_dim_0 and place a breakpoint at csrc/preseg_passes/pre_segmenter.cpp:43, i.e. just before OptimizationPass<InsertReshardingsPass> is applied to the fusion.

Before this pass, the fusion reads as

T1_g_float[ideviceIdx.x4{i0}, iS5{i2}, iS6{i3}, iS7{i4}] (DeviceMesh{0 1})
   = Set( T0_g_float[iS0{i0}, iS1{i2}, iS2{i3}, iS3{i4}] (DeviceMesh{0 1}), cache_op=Streaming )
T2_g_float[ideviceIdx.x8{i0}, iS9{i2}, iS10{i3}, iS11{i4}] (DeviceMesh{0 1})
   = T1_g_float[ideviceIdx.x4{i0}, iS5{i2}, iS6{i3}, iS7{i4}] (DeviceMesh{0 1})
   + T1_g_float[ideviceIdx.x4{i0}, iS5{i2}, iS6{i3}, iS7{i4}] (DeviceMesh{0 1});
T3_g_float[rS12{i0}, iS13{i2}, iS14{i3}, ideviceIdx.x15{i4}] (DeviceMesh{0 1})
   = reduction( T2_g_float[ideviceIdx.x8{i0}, iS9{i2}, iS10{i3}, iS11{i4}] (DeviceMesh{0 1}), op = add, initial value = float(0), allreduce = false )

and after this pass:

T1_g_float[ideviceIdx.x4{i0}, iS5{i2}, iS6{i3}, iS7{i4}] (DeviceMesh{0 1})
   = Set( T0_g_float[iS0{i0}, iS1{i2}, iS2{i3}, iS3{i4}] (DeviceMesh{0 1}), cache_op=Streaming )
T2_g_float[ideviceIdx.x8{i0}, iS9{i2}, iS10{i3}, iS11{i4}] (DeviceMesh{0 1})
   = T1_g_float[ideviceIdx.x4{i0}, iS5{i2}, iS6{i3}, iS7{i4}] (DeviceMesh{0 1})
   + T1_g_float[ideviceIdx.x4{i0}, iS5{i2}, iS6{i3}, iS7{i4}] (DeviceMesh{0 1});
T3_l_float[rdeviceIdx.x12{i0}, iS13{i2}, iS14{i3}, iS15{i4}] (DeviceMesh{0 1})
   = reduction( T2_g_float[ideviceIdx.x8{i0}, iS9{i2}, iS10{i3}, iS11{i4}] (DeviceMesh{0 1}), op = add, initial value = float(0), allreduce = false )
T4_g_float[iS16{i2}, iS17{i3}, ideviceIdx.x18{i4}] (DeviceMesh{0 1})
   = Set( T3_l_float[rdeviceIdx.x12{i0}, iS13{i2}, iS14{i3}, iS15{i4}] (DeviceMesh{0 1}), cache_op=Streaming )

and we see that rdeviceIdx.x appears.

> rDID (unlike r) means the data only exists in one GPU and the collectives nvFuser practically uses today (e.g. allreduce and reducescatter) don't do that.

OK, but I think the present patch on isSharded and the other functions is still relevant anyway. Don't you agree?

The Collaborator replied:

Looking... I'll have to run an earlier version of this branch to understand what exactly failed.

@samnordmann (Author) replied:

> Looking... I'll have to run an earlier version of this branch to understand what exactly failed.

Yes, sorry about that. Let me know how I can help. We can set up a meeting if you want.

The Collaborator replied:

Incidental note: .iter_type(IterType::Reduction), which you pointed out, is correct; resetSchedulingParams() resets the parallel type to Serial, so it'll be r instead of rDID.

The Collaborator replied:

> Another more subtle case occurs from InsertReshardingsPass

This is related to the isInnerResharding change in this PR, so I'll comment over there.

The Collaborator replied:

> we should assert that this case is not encountered

Yep, it's unfortunately one of the many places in nvFuser where a contract is not fully verified, and PRs are welcome. In the meantime, how about moving id->parallelize(ParallelType::Serial); to shardAllLike? It has been the biggest source of rDID.

@samnordmann (Author) replied:

> how about moving id->parallelize(ParallelType::Serial); to shardAllLike? It has been the biggest source of rDID.

Sounds good!

continue;
}

@@ -160,7 +160,7 @@ int64_t getShardedLogicalAxis(
std::unordered_map<ParallelType, IterDomain*> parallel_type_to_id =
mapDeviceParallelTypeToId(tv->getMaybeAllocationDomain());
IterDomain* alloc_id = getOrDefault(parallel_type_to_id, parallel_type);
if (alloc_id == nullptr) {
if (alloc_id == nullptr || alloc_id->isReduction()) {
return -1;
}

@@ -417,7 +417,9 @@ bool haveDifferentShardings(
.strictAreMapped(a, b);
};

if (!is_mapped_in_id_model(p_loop_id, c_loop_id, id_model)) {
if (!is_mapped_in_id_model(p_loop_id, c_loop_id, id_model) ||
(p_loop_id != nullptr && c_loop_id != nullptr &&
p_loop_id->isReduction() != c_loop_id->isReduction())) {
return true;
}
}
13 changes: 4 additions & 9 deletions csrc/preseg_passes/reorder_sharded_axis.cpp
@@ -25,21 +25,16 @@ void ReorderShardedAxisPass::runPass(Fusion* fusion) {
const std::vector<Expr*>& exprs = fusion->exprs();
for (auto it = std::rbegin(exprs); it != std::rend(exprs); it++) {
Expr* expr = *it;
if (!isResharding(expr)) {
if (HostIrLower::canLower(expr)) {
continue;
}
if (expr->outputs().size() > 1 || expr->inputs().size() > 1) {
A Collaborator commented on the multiple-I/O check above:

Can you code-comment this? I suspect this is to work around some limitations in insert_resharding and reorder_sharded_axis for the stream-parallelized matmul you are working on. Otherwise, all non-lowerable resharding expressions would have been decomposed.

Eventually, insert_resharding should generate the following:

[image omitted: the expected decomposed fusion IR]

and reorder_sharded_axis should do nothing for the allgather because the DIDx axis is outermost allocated (note that S in the allgather output is stream-parallelized and therefore has allocation size of 1).

@samnordmann (Collaborator, Author) replied on Jan 8, 2025:

> Can you code-comment this? I suspect this is to work around some limitations in insert_resharding and reorder_sharded_axis for the stream-parallelized matmul you are working on. Otherwise, all non-lowerable resharding expressions would have been decomposed.

Before the patch, the pass threw an error if the expr had multiple I/O. After this patch, we don't throw, we only pass. There is nothing fundamental to that: when needed, a future PR could extend this pass to also support multiple I/O. In any case, we don't rely on this pass for the distributed matmul test added by this patch.

> and reorder_sharded_axis should do nothing for the allgather because the DIDx axis is outermost allocated (note that S in the allgather output is stream-parallelized and therefore has allocation size of 1).

That is not correct, the stream axis is fully allocated.

The Collaborator replied:

> After this patch, we don't throw, we only pass.

To rephrase my previous comment, I was trying to say this is a wrong change. InsertReshardingsPass (which runs before ReorderShardedAxis) should already have decomposed each resharding expression into local expressions and resharding expressions that can be lowered to a communication (modulo the axis order, which this pass tries to fix). All communications today take one TV and produce one TV, so there was nothing wrong with the old code here erroring out when seeing a multiple-I/O resharding expression.

Therefore, I was trying to understand what triggered you to make this change. Was it to work around a limitation in InsertReshardingsPass?

> That is not correct, the stream axis is fully allocated.

(I brought this up but this no longer matters for the current discussion around multiple I/O. But still I'd like to point out a potential misconception so we can be on the same page for the future!)

I don't think so. A stream-parallelized IterDomain in allocation (in your unit test the same as loop and logical) means the allocation for that axis is sharded, similar to how nvFuser deals with TID and BID. For the allgather output, the allocation ought to be size [1, D, M/S/D, K] and it ought to be done inside the for loop. When we aggressively run each loop iteration on a different stream, the total allocation will be the same as [S, D, M/S/D, K]; however, SOTA tends to limit concurrent streams so allocation is less than that. E.g., a double-buffer approach allocates [2, D, M/S/D, K].

That said, I understand your current implementation fully allocates the allgather output outside the loop. It's valid, just suboptimal. To represent that, I'd like the loop to be stream-parallelized but the allocation to not be stream-parallelized. However, doing so today may run into problems, as we don't support DID loop split. So I'm definitely OK with some temporary workarounds.

@samnordmann (Author) replied:

> Therefore, I was trying to understand what triggered you to make this change. Was it to work around a limitation in InsertReshardingsPass?

Without this change, DistributedTransformerTest throws.

> (I brought this up but this no longer matters for the current discussion around multiple I/O. But still I'd like to point out a potential misconception so we can be on the same page for the future!)

no problem!

I got your point, which makes a lot of sense. The only thing I am still not sure I understand is:

> reorder_sharded_axis should do nothing for the allgather because the DIDx axis is outermost allocated

In our case, according to your convention, the loop axis is stream-parallelized, but the allocation axis is not. Then reorder_sharded_axis should do nothing but DID is still not outermost allocated.

The Collaborator replied:

> Then reorder_sharded_axis should do nothing but DID is still not outermost allocated

It depends on the TV representing the allgather output.

For me, what should happen is:

  1. InsertReshardingsPass creates a Set before the Matmul. The output TV of that set has loop/allocation=[iStream{S},D,M/SD,K].
  2. Host IR lowering sees that and c both have iStream{S} and decides to inline/fuse allgather and matmul into the same host loop.
  3. Because allgather_out has iStream{S} in allocation, host IR lowering will generate an Allocate inside the loop for size [1,D,M/SD,K].
  4. Some downstream host IR optimization inserts Deallocate.
  5. Some downstream host IR optimization decides the number of streams (typically smaller than S) and adds control-flow dependencies so Deallocate is guaranteed to happen early enough. Otherwise, we may have all the S [1,D,M/SD,K]s alive at peak.

A suboptimal alternative is:

  1. InsertReshardingsPass creates a Set before the Matmul. The output TV of that set has loop=[iStream{S},i{D},i{M/SD},i{K}] but allocation=[i{S},i{D},i{M/SD},i{K}].
  2. same
  3. Because allgather_out has i{S} in allocation, host IR lowering will generate an Allocate outside the loop for size [S,D,M/SD,K].
  4. same
  5. doesn't matter because allgather_out is allocated outside the loop and its size won't be affected anyway

My earlier statement was describing the former, and what your patch implements is close to the latter. I guess that's where your confusion came from.

That said, does InsertReshardingsPass kick in for your unit test? https://github.com/NVIDIA/Fuser/pull/3606/files#diff-85674f0bb25ed74e0f94deeb9af9c3d9a5a1f43ce6a7b51339ab9cb25c365303R382-R385 lets host IR lowering create the allgather output TV, giving me the impression InsertReshardingsPass doesn't kick in.

The Collaborator replied:

> Without this change, DistributedTransformerTest throws.

I'll run the test to understand what's going on. This PR doesn't change DistributedTransformerTest and the test passes at head, so there must be something else in this PR that triggered the throw.

continue;
}
NVF_ERROR(
ir_utils::isTvOp(expr),
"Non-tv op is not supported: ",
expr->toString());
NVF_ERROR(
expr->outputs().size() == 1,
"Resharding operations can only have one output: ",
expr->toString());
NVF_ERROR(
expr->inputs().size() == 1,
"Resharding operations can have only one input: ",
expr->toString());
auto* output = expr->outputs().at(0)->as<TensorView>();
auto* input = expr->inputs().at(0)->as<TensorView>();
auto [shard_additions, shard_deletions] = getShardingChanges(input, output);
2 changes: 2 additions & 0 deletions csrc/type.cpp
@@ -712,6 +712,8 @@ static const char* parallel_type2string(ParallelType t) {
return "threadIdx.y";
case ParallelType::TIDx:
return "threadIdx.x";
case ParallelType::Stream:
return "Stream";
case ParallelType::Vectorize:
return "V";
case ParallelType::MisalignedVectorize:
1 change: 1 addition & 0 deletions csrc/type.h
@@ -672,6 +672,7 @@ enum class ParallelType {
TIDz,
TIDy,
TIDx,
Stream,
A Collaborator commented on the new ParallelType::Stream:

@naoyam and @csarofeen as a non-blocking comment, I'm curious about your thoughts on naming this new parallel type because we'll use it a lot and renaming too late can be expensive.

Background: this new parallel type will be used to represent the intention to decompose a kernel/communication so they can overlap for shorter latency.

ParallelType::Stream gives readers the oversimplified impression that fine-grained kernels/communications are parallelized on distinct streams. In reality, it's hard to determine stream assignment (e.g. how many streams to create, which fine-grained kernels/communications to assign to which streams, and control-flow dependencies among kernels and communications) based on fusion IR. It's better determined based on lowered host IR. For example,

  1. For memory, we'll likely have to limit the number of in-flight communications, e.g., the rate limiter in FSDP as described in https://arxiv.org/pdf/2304.11277 section 3.4.2. And the current PR adopted a similar limitation via % numberOfStreams.
  2. Stream assignment depends on the order of host IR instructions. https://dl.acm.org/doi/pdf/10.1145/3567955.3567959 page 99 algorithm 2 is an example how complicated instruction ordering can become.
  3. Context parallelism (e.g. the ring attention algorithm in https://arxiv.org/pdf/2310.01889 page 15 figure 4) may require that a host loop serializes the fine-grained kernels so blockwise stats like rowmax and logsumexp are passed from one block to the next. Most naturally, these kernels will be assigned to the same "compute" stream and ring communications are assigned to one or multiple "communication" streams.

Alternatively, we could call this ParallelType::HostLoop, which matches the intention to parallelize fine-grained kernels/communications on distinct host loop iterations, leaving stream assignment to a downstream optimization on host IR. HostSerial might work too in analogy to the existing Serial. However, I'm open to other options.

I'm also curious what parallel type you would use/call to implement algorithms like FlashAttention-3. It tries to "overlap overall computation and data movement via warp-specialization", a similar idea. We may want to reuse the same parallel type or make the names consistent.

@samnordmann (Author) replied:

Thank you for listing all these interesting examples. I would be happy to discuss them. ParallelType::HostLoop sounds OK to me; but for the sake of the discussion, at the moment I personally still think Stream is a better name. The Stream parallel type conveys the idea that each tile is produced on a different stream, hence the notion of parallelizing across streams. On the contrary, HostLoop and HostSerial convey the wrong idea that the computation is sequential.

IMO, Stream parallelization as we intend it here matches the other parallel types pretty well. In the analogy with ParallelType::TIDx:

  • the index in the host for-loop corresponds to streamIdx.x, and
  • numberOfStreams corresponds to blockDim.x.

However, let me point out a small difference that goes in favor of your argument: contrary to the other parallel types, we allow numberOfStreams to differ from the parallelized axis's extent, which is why we use % numberOfStreams in the for-loop. This makes things more flexible. In particular, recall that PyTorch's allocator cache does not behave well with multi-streaming, which is mainly why we want to use a limited number of streams, even on larger axes.
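As a concrete illustration of the round-robin described above (a standalone sketch, not nvFuser code; the stream-pool size and the loop body are placeholders):

#include <cuda_runtime.h>
#include <cstdint>
#include <vector>

// Distribute S loop iterations round-robin over a small pool of CUDA streams,
// mirroring the j % numberOfStreams indexing used in the lowered host loop.
void runRoundRobin(int64_t S, int64_t number_of_streams) {
  std::vector<cudaStream_t> streams(number_of_streams);
  for (auto& s : streams) {
    cudaStreamCreate(&s);
  }
  for (int64_t j = 0; j < S; ++j) {
    cudaStream_t stream = streams[j % number_of_streams];  // the "streamIdx" of the analogy
    // ... enqueue iteration j's communication and compute on `stream` ...
    (void)stream;
  }
  for (auto& s : streams) {
    cudaStreamDestroy(s);
  }
}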

> ParallelType::Stream gives readers the oversimplified impression that fine-grained kernels/communications are parallelized on distinct streams.

Why? At the tensor level, there is not even a notion of compute vs. comms. The Stream parallel type could even be used in single-device situations.

> 1. For memory, we'll likely have to limit the number of in-flight communications, e.g., the rate limiter in FSDP as described in https://arxiv.org/pdf/2304.11277 section 3.4.2. And the current PR adopted a similar limitation via % numberOfStreams.

Warning: let me recall that, currently, there will always be only one in-flight communication, regardless of the number of streams we use. This is because we only use one process group, and because the communications are posted on the pg's internal stream. So the Streams considered here are in fact only compute streams.

I still need more time to fully process the other examples you propose. However, I am afraid that some situations do not even match the framework of "annotating a parallel type to a tensor's axis"...

A Collaborator replied:

Both points seem to make sense to me. I don't have any particular preference at this point. The only suggestion I'd make is to make it easy to change the name in the future.

Re: % numberOfStreams. The usual approach for using a fixed number of threads/blocks with TID/BID is to split the iter domain by that number. For example:

t0->split(i, numberOfStreams);
t0->axis(i + 1)->parallelize(ParallelType::Stream);

This would make the decision of the number of streams more explicit as a scheduling parameter.

> I am afraid that some situations do not even match the framework of "annotating a parallel type to a tensor's axis"...

That may be the case, and that's also why I don't have a strong preference at this moment.

@samnordmann (Author) replied:

> Re: % numberOfStreams. The usual approach for using a fixed number of threads/blocks with TID/BID is to split the iter domain by that number. For example:

Yes, but the difference in our case is that we want to round-robin over fewer streams. IOW, we want to parallelize the full axis, but use a limited pool of actual CUDA streams at runtime.

Can that be achieved with what you are suggesting?

The Collaborator replied:

Yes. Within a device, here's what we typically do. Suppose we want to parallelize something by just using a single thread block consisting of 128 threads, then:

// t0: [N]
t0->split(0, 128);
t0->axis(0)->parallelize(ParallelType::Serial);
t0->axis(1)->parallelize(ParallelType::TIDx);

This would be translated to something like:

for (i in ceilDiv(N/128)) {
  t0[i * 128  + threadIdx.x] = ...;
}

Similarly, for the host case, I'd imagine it would look like:

// t0: [N]
t0->split(0, 128);
t0->axis(0)->parallelize(ParallelType::Serial); // May want to use `ParallelType::HostSerial` to differentiate from `Serial`
t0->axis(1)->parallelize(ParallelType::Stream);
for (i in ceilDiv(N/128)) {
  launch_on(..., stream[streamIdx]);
}

Here, 0 <= streamIdx < 128.

In the above TIDx example, it doesn't need to be a static constant at the scheduling time. We could do something like:

// t0: [N]
t0->split(0, 4, /*inner_split=*/false);
t0->axis(0)->parallelize(ParallelType::Serial);
t0->axis(1)->parallelize(ParallelType::TIDx);

Note that t0->axis(1)->extent() is now symbolic. At the launch time, its actual size is automatically resolved based on the actual dimension of t0.

The Collaborator added:

There's also a CUDA environment variable that may also be useful:

https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#cuda-environment-variables

CUDA_DEVICE_MAX_CONNECTIONS (1 to 32, default 8): sets the number of compute and copy engine concurrent connections (work queues) from the host to each device of compute capability 3.5 and above.

@samnordmann (Author) replied:

I see, thank you, it makes a lot of sense to do as you suggest. I will think about it for a future PR.

Vectorize,
MisalignedVectorize,
Unroll,
59 changes: 59 additions & 0 deletions tests/cpp/test_multidevice_host_ir.cpp
@@ -5,6 +5,7 @@
* SPDX-License-Identifier: BSD-3-Clause
*/
// clang-format on
#include <cuda_profiler_api.h>
#include <fusion.h>
#include <host_ir/container.h>
#include <host_ir/executor.h>
@@ -349,6 +350,64 @@ TEST_F(P2PCommHostIrTest, CoalescedRingPairwiseExchange) {
EXPECT_TRUE(torch::allclose(ref_output, outputs.back()));
}

using OverlapDistributedMatmulTest = MultiDeviceTest;

TEST_F(OverlapDistributedMatmulTest, AG_matmul) {
constexpr int64_t M = 32768;
constexpr int64_t K = 32768;
constexpr int64_t N = 1024;
constexpr int64_t S = 8;
const int64_t D = communicator_->size();
if (M % (D * S) != 0) {
GTEST_SKIP() << "M must be a multiple of D * S, but got M = " << M
<< ", D = " << D << ", S = " << S;
}

auto fusion = std::make_unique<Fusion>();
FusionGuard fg(fusion.get());

TensorView* a = makeContigTensor(4); //[S, DIDx(D), M/(S*D), K]
TensorView* b = makeContigTensor(2); //[K, N]
TensorView* c = matmul(a, b); //[S, D, M/(S*D), N]

fusion->addInput(a);
fusion->addInput(b);
fusion->addOutput(c);

auto mesh = DeviceMesh::createForNumDevices(D);
a->setDeviceMesh(mesh);
b->setDeviceMesh(mesh);
c->setDeviceMesh(mesh);

a->axis(1)->parallelize(ParallelType::DIDx);
c->axis(0)->parallelize(ParallelType::Stream);

MultiDeviceExecutor executor(std::move(fusion), *communicator_);

auto tensor_options =
at::TensorOptions().dtype(at::kFloat).device(communicator_->device());
at::Tensor ta_unsharded = at::randn({S, D, M / (S * D), K}, tensor_options);
at::Tensor ta = ta_unsharded.slice(
1, communicator_->deviceId(), communicator_->deviceId() + 1);
at::Tensor tb = at::randn({K, N}, tensor_options);
at::Tensor tc_ref = at::matmul(ta_unsharded, tb);

std::vector<c10::IValue> inputs = {ta, tb};
at::Tensor tc;

constexpr int64_t kNumberOfIterations = 20;
constexpr int64_t kNumberOfWarmupIterations = 5;
for (auto i : c10::irange(kNumberOfIterations)) {
if (i == kNumberOfWarmupIterations) {
cudaProfilerStart();
}
tc = executor.runWithInput(inputs).at(0);
}
cudaProfilerStop();

EXPECT_TRUE(torch::allclose(tc_ref, tc));
}

} // namespace hir

} // namespace nvfuser