Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Lower distributed matmul to pipelined algorithm for fine-grained overlap: AG+GEMM layout #3606

Merged
merged 51 commits into from
Jan 13, 2025

Conversation

samnordmann
Copy link
Collaborator

@samnordmann samnordmann commented Dec 18, 2024

Stacked on top of

What

Lower a MatmulOp sharded on the first inner axis into a pipelined AG+GEMM algorithm achieving fine grained overlap.

We introduce a new parallel type Stream to account for this scheduling.

More precisely, this patch enables lowering the fusion:

  TensorView* a = makeContigTensor(4); //[S, DIDx(D), M/(S*d), K]
  TensorView* b = makeContigTensor(2); //[K, N]
  TensorView* c = matmul(a, b); //[S, D, M/(S*D), N]

  fusion->addInput(a);
  fusion->addInput(b);
  fusion->addOutput(c);

  auto mesh = DeviceMesh::createForNumDevices(D);
  a->setDeviceMesh(mesh);
  b->setDeviceMesh(mesh);
  c->setDeviceMesh(mesh);

  a->axis(1)->parallelize(ParallelType::DIDx);
  c->axis(0)->parallelize(ParallelType::Stream);

to the Host Ir program (obtained from dump, using NVFUSER_DUMP=host_ir)

%HostIrContainer { (T0_g_float[iS0{i0}, ideviceIdx.x1{i2}, iS2{i3}, iS3{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), T1_g_float[iS4{i5}, iS5{i6}] (DeviceMesh{0 1 2 3 4 5 6 7})) -> (T2_g_float[iStream6{i0}, iS7{i2}, iS8{i3}, iS9{i6}, rS10{i4}] (DeviceMesh{0 1 2 3 4 5 6 7})) :
  GetCurrentStream into Stream 0
  T3_g_float[iS11{i0}, iS12{i2}, iS13{i3}, iS14{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T3_g_float[iS11{i0}, iS12{i2}, iS13{i3}, iS14{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( ( ( i0 * i2 ) * i3 ) * i4 ), zero_init=false, resets_to_zero=fals
e)
  T2_g_float[iStream6{i0}, iS7{i2}, iS8{i3}, iS9{i6}, rS10{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T2_g_float[iStream6{i0}, iS7{i2}, iS8{i3}, iS9{i6}, rS10{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( ( ( i0 * i2 ) * i3 ) * i6 ), zero_init=fals
e, resets_to_zero=false)
  FOR i104 in iS0{i0}:
    SetCurrentStream to Stream ( i104 % numberOfStreams )
    T4_l_float[ideviceIdx.x15{i2}, iS16{i3}, iS17{i4}] (DeviceMesh{0 1 2 3 4 5 6 7})
       = select( T0_g_float[iS0{i0}, ideviceIdx.x1{i2}, iS2{i3}, iS3{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iS0{i0}, index = i104 )
    T5_l_float[iS18{i2}, iS19{i3}, iS20{i4}] (DeviceMesh{0 1 2 3 4 5 6 7})
       = select( T3_g_float[iS11{i0}, iS12{i2}, iS13{i3}, iS14{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iS11{i0}, index = i104 )
    Communication 46 (type=Allgather, team=(0 1 2 3 4 5 6 7), input=T4_l_float[ideviceIdx.x15{i2}, iS16{i3}, iS17{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), output=T5_l_float[iS18{i2}, iS19{i3}, iS20{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}))
    Wait Communication 46
    T6_l_float[iS21{i2}, iS22{i3}, iS23{i6}] (DeviceMesh{0 1 2 3 4 5 6 7})
       = select( T2_g_float[iStream6{i0}, iS7{i2}, iS8{i3}, iS9{i6}, rS10{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iStream6{i0}, index = i104 )
    T6_l_float[iS21{i2}, iS22{i3}, iS23{i6}] (DeviceMesh{0 1 2 3 4 5 6 7})
       = matmul(T5_l_float[iS18{i2}, iS19{i3}, iS20{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}),
                T1_g_float[iS4{i5}, iS5{i6}] (DeviceMesh{0 1 2 3 4 5 6 7}))
    SetCurrentStream to Stream 0
    Synchronize Stream ( i104 % numberOfStreams )
} // %HostIrContainer

The nsight profile shows that we do achieve overlap, in a way that is comparable to the Aten overlap experiments

Screenshot 2024-12-18 at 12 08 05

@samnordmann samnordmann force-pushed the overlap/lower_matmul_to_hostir branch from bb867e8 to b517c2b Compare December 18, 2024 08:47
@samnordmann
Copy link
Collaborator Author

!test

csrc/host_ir/lower.cpp Outdated Show resolved Hide resolved
csrc/host_ir/lower.cpp Outdated Show resolved Hide resolved
csrc/host_ir/lower.cpp Outdated Show resolved Hide resolved
samnordmann added a commit that referenced this pull request Dec 23, 2024
# What

Make stream synchronization non-blocking from the CPU point of view

# Why

Needed for achieving overlap in 
- #3606

before this patch:
![Screenshot 2024-12-18 at 12 08
25](https://github.com/user-attachments/assets/f5c84282-ea85-4cb8-8a60-538cd91cfa1c)
after this patch
![Screenshot 2024-12-18 at 12 08
05](https://github.com/user-attachments/assets/25537a5d-3e33-4ff8-baf4-4f013c1ed230)


# How 

Before this patch, the host IR `Synchronize` would call
`c10::synchronize()` on the cuda stream, which makes the CPU blocks
until stream completion. With this patch, we synchronize the current
stream with a given stream through a `cudaEvent` and the API
`cudaStreamWaitEvent`.
samnordmann added a commit that referenced this pull request Dec 23, 2024
# What

adds the primitive `GetCurrentStream` to Host Ir stack.

# Why

needed for 
- #3606

The idea is that if we want to use multiple stream internally, we need
before hand to capture the user stream and to set it back to being the
active stream when returning
@samnordmann
Copy link
Collaborator Author

again, I had to add a couple of additional small fixes to account for some other tests...

@samnordmann
Copy link
Collaborator Author

!test

is_sharded = true;

if (alloc_id->isReduction()) {
is_reduction_sharded = true;
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is_reduction_sharded is only used here to check that there are not two axis DID-sharded. I am not convinced the checks necessarily needs to be done in this function. Another option could be to modify ShardingTest.ReductionShouldNotBeSharded

tests/cpp/test_multidevice_matmul.cpp Outdated Show resolved Hide resolved
csrc/host_ir/lower.cpp Show resolved Hide resolved
Comment on lines 28 to 31
if (HostIrLower::canLower(expr)) {
continue;
}
if (expr->outputs().size() > 1 || expr->inputs().size() > 1) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After this patch, we don't throw, we only pass.

To rephrase my previous comment, I was trying to say this is a wrong change. InsertReshardingsPass (which runs before ReorderShardedAxis) should already have decomposed each resharding expression into local expressions and resharding expressions that can be lowered to a communication (modulo the axis order which this pass tries to fix). All communications today takes one TV and produces one TV, so there's nothing wrong with the old code here to error out when seeing a multiple-I/O resharding expression.

Therefore, I was trying to understand what triggered you to make this change. Was it to work around a limitation in InsertReshardingsPass?

That is not correct, the stream axis is fully allocated.

(I brought this up but this no longer matters for the current discussion around multiple I/O. But still I'd like to point out a potential misconception so we can be on the same page for the future!)

I don't think so. A stream-parallelized IterDomain in allocation (in your unit test the same as loop and logical) means the allocation for that axis is sharded, similar to how nvFuser deals with TID and BID. For the allgather output, the allocation ought to be size [1, D, M/S/D, K] and it ought to be done inside the for loop. When we aggressively run each loop iteration on a different stream, the total allocation will be the same as [S, D, M/S/D, K]; however, SOTA tends to limit concurrent streams so allocation is less than that. E.g., a double-buffer approach allocates [2, D, M/S/D, K].

That said, I understand your current implementation fully allocates the allgather output outside the loop. It's valid just suboptimal. To represent that, I'd like the loop to be stream parallelized but allocation to not be stream parallelized. However, doing so today may run into problems as we don't support DID loop split. So I'm definitely OK with some temporary workarounds.

@samnordmann
Copy link
Collaborator Author

The CI is finally green!

But @wujingyue I'll wait your final word before merging

@@ -100,7 +100,7 @@ std::pair<std::vector<IterDomain*>, std::vector<IterDomain*>> getShardingChanges
bool isSharded(const TensorView* tv) {
bool is_sharded = false;
for (IterDomain* alloc_id : tv->getMaybeAllocationDomain()) {
if (!alloc_id->isDeviceDim()) {
if (!alloc_id->isDeviceDim() || alloc_id->isReduction()) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Incidental note:

.iter_type(IterType::Reduction)
, what you pointed out, is correct. resetSchedulingParams() resets the parallel type to Serial so it'll be r instead of rDID.

@@ -100,7 +100,7 @@ std::pair<std::vector<IterDomain*>, std::vector<IterDomain*>> getShardingChanges
bool isSharded(const TensorView* tv) {
bool is_sharded = false;
for (IterDomain* alloc_id : tv->getMaybeAllocationDomain()) {
if (!alloc_id->isDeviceDim()) {
if (!alloc_id->isDeviceDim() || alloc_id->isReduction()) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Another more subtle case occurs from InsertReshardingsPass

This is related to the isInnerResharding change in this PR, so I'll comment over there.

csrc/host_ir/lower.cpp Outdated Show resolved Hide resolved
@@ -100,7 +100,7 @@ std::pair<std::vector<IterDomain*>, std::vector<IterDomain*>> getShardingChanges
bool isSharded(const TensorView* tv) {
bool is_sharded = false;
for (IterDomain* alloc_id : tv->getMaybeAllocationDomain()) {
if (!alloc_id->isDeviceDim()) {
if (!alloc_id->isDeviceDim() || alloc_id->isReduction()) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should assert that this case is not encountered

Yep, it's unfortunately one of the many places in nvFuser where a contract is not fully verified. And PRs are welcomed. In the meantime, how about moving

id->parallelize(ParallelType::Serial);
to shardAllLike? It has been the biggest source of rDID.

@samnordmann
Copy link
Collaborator Author

!test

@samnordmann
Copy link
Collaborator Author

!test

1 similar comment
@samnordmann
Copy link
Collaborator Author

!test

Copy link
Collaborator

@wujingyue wujingyue left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM :shipit:

csrc/host_ir/lower.cpp Outdated Show resolved Hide resolved
csrc/host_ir/lower.h Show resolved Hide resolved
@samnordmann
Copy link
Collaborator Author

!test

@samnordmann samnordmann merged commit 33366f9 into NVIDIA:main Jan 13, 2025
40 of 41 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants