Lower distributed matmul to pipelined algorithm for fine-grained overlap: AG+GEMM layout #3606
This along with several other changes is for supporting rDID? That's intentionally unsupported, because rDID (unlike r) means the data only exists on one GPU, and the collectives nvFuser practically uses today (e.g. allreduce and reducescatter) don't do that. I'll need to figure out where rDID came from. It's unexpected because your test doesn't try to parallelize a reduction dimension in the first place. cc @naoyam
This patch is not needed in my test, but indeed the case rDID appears in other tests which were previously added. In that case, we should assert that it is not encountered (but it actually is in the present state). For now, nothing prevents this case from occurring (see Fuser/csrc/ops/arith.cpp, line 1203 in 9ce2112). For example, take ReduceScatter/PipelineTestTwoStages.Communication/7 with GetParam() = (NCCL, DeviceMesh{1 0}, DeviceMesh{}, true, true, true, 1, false).

Another, more subtle case occurs from InsertReshardingsPass. Take for example MultiDeviceReductionTest.UnshardedInput_ShardedOutput/symbolic_sharded_along_dim_0. Place a breakpoint at csrc/preseg_passes/pre_segmenter.cpp:43, i.e. just before applying OptimizationPass<InsertReshardingsPass> on the fusion. Comparing the fusion printed before and after this pass, we see that rdeviceIdx.x appears.

OK, but anyway I think the present patch on isSharded and the other functions is still relevant. Don't you agree?
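For illustration, a minimal sketch of the kind of assertion being discussed, assuming nvFuser's IterDomain API (isReduction(), isDeviceDim(), toString()) and the NVF_ERROR macro; tv is a placeholder for the TensorView being checked, and this is not the code of this PR:

```cpp
// Hypothetical guard: reject a reduction IterDomain that is also
// device-parallelized (rDID), since the collectives currently lowered
// (allreduce, reducescatter) never produce one.
for (IterDomain* id : tv->getLoopDomain()) {
  NVF_ERROR(
      !(id->isReduction() && id->isDeviceDim()),
      "Unexpected reduction IterDomain parallelized on DID: ",
      id->toString());
}
```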
Looking... I'll have to run an earlier version of this branch to understand what exactly failed.
Yes, sorry about that. Let me know how I can help. We can set up a meeting if you want.
Incidental note: Fuser/csrc/ops/arith.cpp, line 1203 in 9ce2112: r instead of rDID.
This is related to the isInnerResharding change in this PR, so I'll comment over there.
Yep, it's unfortunately one of the many places in nvFuser where a contract is not fully verified, and PRs are welcome. In the meantime, how about moving Fuser/csrc/preseg_passes/propagate_shardings.cpp, line 137 in e172781, to also handle rDID?
sounds good!
Can you code-comment this? I suspect this is to work around some limitations in insert_resharding and reorder_sharded_axis for the stream-parallelized matmul you are working on. Otherwise, all non-lowerable resharding expressions would have been decomposed.
Eventually, insert_resharding should generate the following
and reorder_sharded_axis should do nothing for the allgather because the DIDx axis is outermost allocated (note that S in the allgather output is stream-parallelized and therefore has allocation size of 1).
Before the patch, the pass was throwing an error if the expr had multiple I/O. After this patch, we don't throw; we simply skip it. There is nothing fundamental to that. When needed, in a future PR, we could extend this pass to also support multiple I/O. But anyway, we don't rely on this pass for the distributed matmul test added by this patch.

That is not correct: the stream axis is fully allocated.
To rephrase my previous comment, I was trying to say this is a wrong change. InsertReshardingsPass (which runs before ReorderShardedAxis) should already have decomposed each resharding expression into local expressions and resharding expressions that can be lowered to a communication (modulo the axis order, which this pass tries to fix). All communications today take one TV and produce one TV, so there's nothing wrong with the old code here erroring out when seeing a multiple-I/O resharding expression.

Therefore, I was trying to understand what triggered you to make this change. Was it to work around a limitation in InsertReshardingsPass?

(I brought this up, but it no longer matters for the current discussion around multiple I/O. Still, I'd like to point out a potential misconception so we can be on the same page for the future!)

I don't think so. A stream-parallelized IterDomain in allocation (in your unit test the same as loop and logical) means the allocation for that axis is sharded, similar to how nvFuser deals with TID and BID. For the allgather output, the allocation ought to be of size [1, D, M/S/D, K], and it ought to be done inside the for loop. When we aggressively run each loop iteration on a different stream, the total allocation will be the same as [S, D, M/S/D, K]; however, SOTA tends to limit concurrent streams, so the allocation is less than that. E.g., a double-buffer approach allocates [2, D, M/S/D, K].

That said, I understand your current implementation fully allocates the allgather output outside the loop. It's valid, just suboptimal. To represent that, I'd like the loop to be stream-parallelized but the allocation to not be stream-parallelized. However, doing so today may run into problems as we don't support DID loop split. So I'm definitely OK with some temporary workarounds.
Without this change, DistributedTransformerTest throws.

No problem! I got your point, which makes a lot of sense. The only thing I am still not sure I understand is: in our case, according to your convention, the loop axis is stream-parallelized but the allocation axis is not. Then reorder_sharded_axis should do nothing, but DID is still not outermost allocated.
It depends on the TV representing the allgather output.

For me, what should happen is:
- The allgather output is [iStream{S}, D, M/SD, K].
- Host IR lowering sees that the allgather output and c both have iStream{S} and decides to inline/fuse the allgather and the matmul into the same host loop.
- With iStream{S} in allocation, host IR lowering will generate an Allocate inside the loop for size [1, D, M/SD, K].
- When all S iterations run concurrently on different streams, there are S [1, D, M/SD, K]s alive at peak.

A suboptimal alternative is:
- loop = [iStream{S}, i{D}, i{M/SD}, i{K}] but allocation = [i{S}, i{D}, i{M/SD}, i{K}].
- With i{S} (not stream-parallelized) in allocation, host IR lowering will generate an Allocate outside the loop for size [S, D, M/SD, K].
- allgather_out is allocated outside the loop, and its size won't be affected anyway.

My earlier statement was describing the former, and what your patch implements is close to the latter. I guess that's where your confusion came from.

That said, does InsertReshardingsPass kick in for your unit test? https://github.com/NVIDIA/Fuser/pull/3606/files#diff-85674f0bb25ed74e0f94deeb9af9c3d9a5a1f43ce6a7b51339ab9cb25c365303R382-R385 lets host IR lowering create the allgather output TV, giving me the impression InsertReshardingsPass doesn't kick in.
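As a rough illustration of the two outcomes, a hand-written sketch (not nvFuser's actual host IR printout; the names and sizes follow the discussion above):

```cpp
// Former (preferred): the allocation domain is stream-parallelized, so the
// buffer is allocated inside the host loop with the stream axis collapsed to 1.
for (int64_t i = 0; i < S; ++i) {   // stream-parallelized host loop
  // Allocate ag_slice: [1, D, M/SD, K]
  // allgather into ag_slice
  // matmul consuming ag_slice
}

// Latter (close to the current patch, valid but suboptimal): the allocation
// domain is not stream-parallelized, so the buffer is allocated once, outside
// the loop.
// Allocate allgather_out: [S, D, M/SD, K]
for (int64_t i = 0; i < S; ++i) {
  // allgather into slice i of allgather_out
  // matmul consuming slice i of allgather_out
}
```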
I'll run the test to understand what's going on. This PR doesn't change DistributedTransformerTest and the test passes at head, so there must be something else in this PR that triggered the throw.
@naoyam and @csarofeen, as a non-blocking comment, I'm curious about your thoughts on naming this new parallel type, because we'll use it a lot and renaming too late can be expensive.

Background: this new parallel type will be used to represent the intention to decompose a kernel/communication so they can overlap for shorter latency.

ParallelType::Stream gives readers the oversimplified impression that fine-grained kernels/communications are parallelized on distinct streams. In reality, it's hard to determine stream assignment (e.g. how many streams to create, which fine-grained kernels/communications to assign to which streams, and control-flow dependencies among kernels and communications) based on fusion IR. It's better determined based on lowered host IR. For example, iterations may end up being assigned to streams via % numberOfStreams.

Alternatively, we could call this ParallelType::HostLoop, which matches the intention to parallelize fine-grained kernels/communications on distinct host loop iterations, leaving stream assignment to a downstream optimization on host IR. HostSerial might work too, in analogy to the existing Serial. However, I'm open to other options.

I'm also curious what parallel type you would use/call to implement algorithms like FlashAttention-3. It tries to "overlap overall computation and data movement via warp-specialization", a similar idea. We may want to reuse the same parallel type or make the names consistent.
Thank you for listing all these interesting examples. I would be happy to discuss them.

ParallelType::HostLoop sounds OK to me; but for the sake of the discussion, at the moment, I personally still think Stream is a better name. Indeed, the Stream parallel type conveys the idea that each tile is produced on a different stream, hence the notion of parallelizing across streams. On the contrary, HostLoop and HostSerial convey the wrong idea that the computation is sequential.

IMO, Stream parallelization as we intend it here matches pretty well with the other parallel types. In the analogy with ParallelType::TIDx: streamIdx.x corresponds to threadIdx.x, and numberOfStream corresponds to blockDim.x.

However, let me point out a small difference that goes in favor of your argument: contrary to other parallel types, we allow numberOfStream to differ from the parallelized axis's extent. This is why we use a % numberOfStreams in the for-loop. This makes things more flexible. In particular, recall that PyTorch's allocator cache does not behave well with multi-streaming, which is mainly why we want to use a limited number of streams, even on larger axes.

Why? At the tensor level, there is not even the notion of compute/comms. The Stream parallel type could even be used in single-device situations.

Warning: let me recall that, currently, there will always be only one in-flight communication, regardless of the number of streams we use. This is because we only use one process group, and because the communications are posted on the pg's internal stream. So the Streams considered here are in fact only compute streams.

I still need more time to fully process the other examples you propose. However, I am afraid that some situations do not even match the framework of "annotating a parallel type on a tensor's axis"...
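To make the round-robin concrete, a minimal standalone sketch of the pattern described above (the function and variable names are illustrative, not nvFuser's host IR):

```cpp
#include <cuda_runtime.h>
#include <cstdint>
#include <vector>

// Iterations of a Stream-parallelized host loop of extent S are distributed
// over a small pool of numberOfStreams compute streams via i % numberOfStreams.
void runPipelined(int64_t S, int numberOfStreams) {
  std::vector<cudaStream_t> streams(numberOfStreams);
  for (auto& s : streams) {
    cudaStreamCreate(&s);
  }
  for (int64_t i = 0; i < S; ++i) {
    cudaStream_t stream = streams[i % numberOfStreams];
    (void)stream;
    // Post the i-th allgather slice and the i-th GEMM tile on `stream`.
    // Caveat from above: the communication actually goes to the process
    // group's internal stream, so only the compute side is multi-streamed.
  }
  for (auto& s : streams) {
    cudaStreamDestroy(s);
  }
}
```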
Both points seem to make sense to me. I don't have any particular preference at this point. The only suggestion I'd make is to make it easy to change the name in the future.
Re: % numberOfStreams. The usual approach for using a fixed number of threads/blocks with TID/BID is to split the iter domain by that number (see the sketch below for an example). This would make the decision of the number of streams more explicit as a scheduling parameter.
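A minimal sketch of that split-based scheduling, assuming nvFuser's TensorView::split / parallelize API and the Stream parallel type introduced in this PR (the factor is a placeholder):

```cpp
constexpr int64_t kNumberOfStreams = 8;  // hypothetical scheduling parameter

// [i{S}, ...] -> [i{ceilDiv(S, kNumberOfStreams)}, i{kNumberOfStreams}, ...]
tv->split(0, kNumberOfStreams);
// The inner axis now has a fixed extent of kNumberOfStreams, so the stream
// count becomes an explicit scheduling parameter; the outer axis remains a
// serial host loop.
tv->axis(1)->parallelize(ParallelType::Stream);
```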
That may be the case, and that's also why I don't have a strong preference at this moment.
Yes, but the difference in our case is that we want to round-robin over fewer streams. IOW, we want to parallelize the full axis but use a limited pool of actual CUDA streams at runtime.

Can that match what you are suggesting?
Yes. Within a device, here's what we typically do. Suppose we want to parallelize something using just a single thread block consisting of 128 threads: we split the iter domain by 128 and parallelize the inner axis with TIDx, which would be translated to a serial loop over the outer axis with threadIdx.x covering the inner one. Similarly, for the host case, I'd imagine the same split with the inner axis stream-parallelized, so that 0 <= streamIdx < 128.

In the above TIDx example, the 128 doesn't need to be a static constant at scheduling time. We could instead schedule things so that t0->axis(1)->extent() is symbolic; at launch time, its actual size is automatically resolved based on the actual dimension of t0. (See the sketch below for all three cases.)
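A hand-written sketch of those cases, assuming nvFuser's TensorView scheduling API (each snippet assumes a fresh, unscheduled tensor; the factors are placeholders):

```cpp
// 1) Device case: a single thread block of 128 threads.
t0->split(0, 128);                                // [ceilDiv(N, 128), 128]
t0->axis(1)->parallelize(ParallelType::TIDx);
// Roughly translates to:
//   for (int i = 0; i < ceilDiv(N, 128); ++i) {
//     ... t0[i * 128 + threadIdx.x] ...          // 0 <= threadIdx.x < 128
//   }

// 2) Host analogue: same split, inner axis bound to the Stream parallel type.
t1->split(0, 128);
t1->axis(1)->parallelize(ParallelType::Stream);
//   for each outer host-loop iteration:
//     ... streamIdx ...                          // 0 <= streamIdx < 128

// 3) The parallelized extent need not be static: with an outer split, axis(1)
//    keeps a symbolic extent, ceilDiv(N, 4), resolved at launch time from the
//    tensor's actual size.
t0->split(0, 4, /*inner_split=*/false);           // [4, ceilDiv(N, 4)]
t0->axis(1)->parallelize(ParallelType::TIDx);
// t0->axis(1)->extent() is symbolic here.
```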
There's also a CUDA environment variable that may also be useful:
https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#cuda-environment-variables
I see, thank you, it makes a lot of sense to do as you suggest. I will think about it for a future PR.