Lower distributed matmul to pipelined algorithm for fine-grained overlap: AG+GEMM layout #3606
Conversation
Force-pushed from bb867e8 to b517c2b
!test
# What

Make stream synchronization non-blocking from the CPU point of view.

# Why

Needed for achieving overlap in #3606.

Before this patch: ![Screenshot 2024-12-18 at 12 08 25](https://github.com/user-attachments/assets/f5c84282-ea85-4cb8-8a60-538cd91cfa1c)

After this patch: ![Screenshot 2024-12-18 at 12 08 05](https://github.com/user-attachments/assets/25537a5d-3e33-4ff8-baf4-4f013c1ed230)

# How

Before this patch, the host IR `Synchronize` would call `c10::synchronize()` on the CUDA stream, which makes the CPU block until stream completion. With this patch, we synchronize the current stream with a given stream through a `cudaEvent` and the API `cudaStreamWaitEvent`, so the wait happens on the GPU side instead of the CPU.
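For reference, a minimal sketch of this event-based pattern using the plain CUDA runtime API; the function name and the omitted error handling are illustrative, not the actual host IR implementation:

```cpp
#include <cuda_runtime.h>

// Make `current` wait for all work submitted so far on `other`,
// without blocking the CPU.
void syncStreamNonBlocking(cudaStream_t current, cudaStream_t other) {
  cudaEvent_t event;
  // Timing disabled: cheaper, and we only need ordering.
  cudaEventCreateWithFlags(&event, cudaEventDisableTiming);
  cudaEventRecord(event, other);          // Mark the point to wait for.
  cudaStreamWaitEvent(current, event, 0); // GPU-side wait; CPU returns immediately.
  cudaEventDestroy(event);                // Safe: freed once the wait completes.
}
```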
# What

Adds the primitive `GetCurrentStream` to the Host IR stack.

# Why

Needed for #3606. The idea is that if we want to use multiple streams internally, we first need to capture the user's stream and set it back as the active stream when returning.
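A minimal sketch of this capture-and-restore pattern, assuming PyTorch's `c10::cuda` stream API; the function and its structure are illustrative, not the actual Host IR primitive:

```cpp
#include <c10/cuda/CUDAStream.h>

void runWithInternalStreams() {
  // Capture the user's (current) stream up front.
  c10::cuda::CUDAStream user_stream = c10::cuda::getCurrentCUDAStream();

  // Run internal work on a stream drawn from the pool.
  c10::cuda::CUDAStream internal = c10::cuda::getStreamFromPool();
  c10::cuda::setCurrentCUDAStream(internal);
  // ... launch internal work ...

  // Restore the user's stream as the active stream before returning.
  c10::cuda::setCurrentCUDAStream(user_stream);
}
```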
Again, I had to add a couple of additional small fixes to account for some other tests...
!test
csrc/multidevice/utils.cpp (Outdated)

```cpp
      is_sharded = true;
    }
    if (alloc_id->isReduction()) {
      is_reduction_sharded = true;
```
`is_reduction_sharded` is only used here to check that there are not two DID-sharded axes. I am not convinced the check necessarily needs to be done in this function. Another option could be to modify `ShardingTest.ReductionShouldNotBeSharded`.
```cpp
    if (HostIrLower::canLower(expr)) {
      continue;
    }
    if (expr->outputs().size() > 1 || expr->inputs().size() > 1) {
```
After this patch, we don't throw, we only pass.
To rephrase my previous comment, I was trying to say this is a wrong change. `InsertReshardingsPass` (which runs before `ReorderShardedAxis`) should already have decomposed each resharding expression into local expressions and resharding expressions that can be lowered to a communication (modulo the axis order, which this pass tries to fix). All communications today take one TV and produce one TV, so there's nothing wrong with the old code here erroring out when seeing a multiple-I/O resharding expression.

Therefore, I was trying to understand what triggered you to make this change. Was it to work around a limitation in `InsertReshardingsPass`?
That is not correct, the stream axis is fully allocated.
(I brought this up but this no longer matters for the current discussion around multiple I/O. But still I'd like to point out a potential misconception so we can be on the same page for the future!)
I don't think so. A stream-parallelized IterDomain in allocation (in your unit test the same as loop and logical) means the allocation for that axis is sharded, similar to how nvFuser deals with TID and BID. For the allgather output, the allocation ought to be size `[1, D, M/S/D, K]` and it ought to be done inside the for loop. When we aggressively run each loop iteration on a different stream, the total allocation will be the same as `[S, D, M/S/D, K]`; however, SOTA tends to limit concurrent streams so allocation is less than that. E.g., a double-buffer approach allocates `[2, D, M/S/D, K]`.
That said, I understand your current implementation fully allocates the allgather output outside the loop. It's valid, just suboptimal. To represent that, I'd like the loop to be stream-parallelized but the allocation to not be stream-parallelized. However, doing so today may run into problems as we don't support DID loop split. So I'm definitely OK with some temporary workarounds.
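As a purely illustrative sketch of the double-buffer idea above (not this PR's implementation, which allocates the whole allgather output up front): iteration `j` reuses buffer `j % 2`, so at most two `[D, M/S/D, K]` slices are live at any time. The `allgatherInto`/`gemmFrom` calls are hypothetical stand-ins, left as comments:

```cpp
#include <cuda_runtime.h>

// Hypothetical double-buffered pipeline over S slices using two streams.
void pipelinedLoop(int S, size_t slice_bytes, cudaStream_t streams[2]) {
  void* buffers[2];
  cudaMalloc(&buffers[0], slice_bytes);
  cudaMalloc(&buffers[1], slice_bytes);

  for (int j = 0; j < S; ++j) {
    int b = j % 2;
    // Reusing buffers[b] implicitly serializes iteration j with
    // iteration j - 2, which ran on the same stream streams[b].
    // allgatherInto(buffers[b], streams[b]);  // communication for slice j
    // gemmFrom(buffers[b], streams[b]);       // compute for slice j
  }

  cudaStreamSynchronize(streams[0]);
  cudaStreamSynchronize(streams[1]);
  cudaFree(buffers[0]);
  cudaFree(buffers[1]);
}
```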
The CI is finally green! But @wujingyue I'll wait for your final word before merging.
csrc/multidevice/utils.cpp (Outdated)

```diff
@@ -100,7 +100,7 @@ std::pair<std::vector<IterDomain*>, std::vector<IterDomain*>> getShardingChanges
 bool isSharded(const TensorView* tv) {
   bool is_sharded = false;
   for (IterDomain* alloc_id : tv->getMaybeAllocationDomain()) {
-    if (!alloc_id->isDeviceDim()) {
+    if (!alloc_id->isDeviceDim() || alloc_id->isReduction()) {
```
Incidental note: line 1203 in 9ce2112 (`.iter_type(IterType::Reduction)`) gives `r` instead of `rDID`.
csrc/multidevice/utils.cpp (Outdated)

```diff
@@ -100,7 +100,7 @@ std::pair<std::vector<IterDomain*>, std::vector<IterDomain*>> getShardingChanges
 bool isSharded(const TensorView* tv) {
   bool is_sharded = false;
   for (IterDomain* alloc_id : tv->getMaybeAllocationDomain()) {
-    if (!alloc_id->isDeviceDim()) {
+    if (!alloc_id->isDeviceDim() || alloc_id->isReduction()) {
```
Another, more subtle case occurs from `InsertReshardingsPass`.

This is related to the `isInnerResharding` change in this PR, so I'll comment over there.
csrc/multidevice/utils.cpp (Outdated)

```diff
@@ -100,7 +100,7 @@ std::pair<std::vector<IterDomain*>, std::vector<IterDomain*>> getShardingChanges
 bool isSharded(const TensorView* tv) {
   bool is_sharded = false;
   for (IterDomain* alloc_id : tv->getMaybeAllocationDomain()) {
-    if (!alloc_id->isDeviceDim()) {
+    if (!alloc_id->isDeviceDim() || alloc_id->isReduction()) {
```
we should assert that this case is not encountered

Yep, it's unfortunately one of the many places in nvFuser where a contract is not fully verified. And PRs are welcome. In the meantime, how about moving `id->parallelize(ParallelType::Serial);` so the resulting IterDomain is `r` rather than `rDID`?
!test

!test

!test
LGTM
!test
Stacked on top of `GetCurrentStream` (#3605).

# What

Lower a MatmulOp sharded on the first inner axis into a pipelined AG+GEMM algorithm achieving fine-grained overlap. We introduce a new parallel type `Stream` to account for this scheduling.

More precisely, this patch enables lowering the fusion to the Host IR program (obtained from dump, using `NVFUSER_DUMP=host_ir`). The nsight profile shows that we do achieve overlap, in a way that is comparable to the ATen overlap experiments.
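For intuition, here is a hypothetical sketch of the pipelined AG+GEMM loop this lowering produces (the actual fusion and generated Host IR are elided above; `allGather` is a stand-in for the real communication, and the tensor names and shapes are assumptions). Each slice is allgathered and matmul'ed on its own stream, so communication of slice j+1 can overlap with compute of slice j:

```cpp
#include <ATen/ATen.h>
#include <c10/cuda/CUDAStream.h>

// Hypothetical pipelined AG+GEMM host loop over S stream-parallel slices.
void agGemm(const at::Tensor& x_shard,    // [S, 1, M/(S*D), K], local slice
            const at::Tensor& w,          // [K, N]
            at::Tensor& y,                // [S, D, M/(S*D), N], output
            at::Tensor& x_gathered,       // [S, D, M/(S*D), K], allgather dst
            int64_t S) {
  c10::cuda::CUDAStream user_stream = c10::cuda::getCurrentCUDAStream();
  for (int64_t j = 0; j < S; ++j) {
    // Each iteration runs on its own stream from the pool.
    c10::cuda::CUDAStream stream = c10::cuda::getStreamFromPool();
    c10::cuda::setCurrentCUDAStream(stream);
    // allGather(x_gathered[j], x_shard[j]);  // hypothetical communication
    at::Tensor out = y[j];
    at::matmul_out(out, x_gathered[j], w);    // compute for slice j
  }
  // Restore the user's stream; a real implementation would also make it
  // wait on the per-iteration streams (cf. the event-based sync above).
  c10::cuda::setCurrentCUDAStream(user_stream);
}
```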