expand RemoveBcastSqueeze to handle unary operations between broadcast/squeeze ops #3643
Conversation
!test
!test
I don't quite understand why the CI failure I'm seeing here doesn't show up on other PRs. The repro does fail on main; I opened #3660 for the failure.
!test
!test
This reverts commit 1f38531.
!test
!test
Just in case, this one:
You meant that the broadcast and squeeze ops are going to be removed as they are consecutive, right?
@@ -318,13 +339,79 @@ TensorView* maybeDoReplacement(TensorView* orig) {
  if (!isReplaceableExpr(second)) {
    return orig;
  }
  AxisOps second_ops = exprToAxisOps(second);
Having a hard time understanding what this function (maybeDoReplacement) is doing. What is the parameter assumed to be? What is supposed to be returned?
I think maybeDoReplacement is trying to merge tv->first->second->orig into a tv->merged->new_out when both first and second are replaceable exprs. i.e. when we have tv->broadcast->squeeze, we might be able to just cancel the two and end up returning tv directly.
The function returns the new_out after the replay. The logic here is that:
if the returned pointer is different from orig, it considers that a replacement has happened and tries the same loop again with new_out;
if the returned pointer is the same as orig, the merge failed, so it skips second, moves on, and pushes the inputs to second onto the stack as new candidates for orig.
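For intuition, here is a minimal sketch of that driver loop (hypothetical function and variable names, not the actual nvFuser source), assuming nvFuser's TensorView/Expr/Val types:

#include <stack>
#include <vector>

// Hedged sketch: walk from outputs toward inputs, interpreting
// maybeDoReplacement's return value as described above.
void runRemoveBcastSqueeze(const std::vector<TensorView*>& outputs) {
  std::stack<TensorView*> candidates;
  for (TensorView* out : outputs) {
    candidates.push(out);
  }
  while (!candidates.empty()) {
    TensorView* orig = candidates.top();
    candidates.pop();
    TensorView* new_out = maybeDoReplacement(orig);
    if (new_out != orig) {
      // A replacement happened: retry the same logic with new_out.
      candidates.push(new_out);
    } else if (Expr* def = orig->definition()) {
      // Merge failed: skip this expr and push its producer TensorViews
      // onto the stack as new candidates for orig.
      for (Val* inp : def->inputs()) {
        if (inp->isA<TensorView>()) {
          candidates.push(inp->as<TensorView>());
        }
      }
    }
  }
}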
So the added logic here is: when we try to swap tv->first->second->orig into tv->replayed_second->replayed_first, we return replayed_second->output(0).
Even though we are not merging two consecutive replaceable operations, by returning replayed_second->output(0) instead of orig, we keep replayed_second as the candidate for the next iteration, effectively stopping the unary op first from preventing us merging neighboring replaceable operations.
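As a concrete illustration (a hypothetical snippet using nvFuser's C++ ops; the exact squeeze signature is assumed), with first = relu and second = squeeze:

// Original chain: tv -> first (relu) -> second (squeeze) -> orig
TensorView* mid = relu(tv);                                // first
TensorView* orig = squeeze(mid, std::vector<int64_t>{1});  // second

// Rewired chain: tv -> replayed_second (squeeze) -> replayed_first (relu)
TensorView* squeezed = squeeze(tv, std::vector<int64_t>{1});  // replayed_second
TensorView* new_out = relu(squeezed);                         // replayed_first

// maybeDoReplacement returns squeezed (replayed_second->output(0)), so the
// next iteration keeps pushing the squeeze toward the producers of tv.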
  return orig;
}

// make sure we preserve the allocation domain on second->output(0)
Why does the allocation domain matter?
Answered in the example below. I think I can use another comment here as well.
// make sure we preserve the allocation domain on second->output(0)
// initializing alloc_domain permutation of second output.
auto second_out_tv = second->output(0)->as<TensorView>();
Is this second_out_tv always the same as orig?
Yes. I now realize that I could have just used orig instead.
I'm not against the approach of this PR, but it's much more complicated than I thought. I think if we could just remove a sequence of the four-op pattern (broadcast, cast, squeeze, cast) directly, that would be simpler.
Yes. Thanks for catching that. Updated.
Yes, the extra cast is added because of the trivial reduction. The alternatives are to just re-order the two passes, or your suggested pattern matching. But this feels a little bit more robust.
std::vector<IterDomain*> tv3_nhwc = {
    tv3->axis(0), tv3->axis(2), tv3->axis(3), tv3->axis(1)};
tv3->setAllocationDomain(tv3_nhwc, true);
fusion.addOutput(tv3);
This is the reason why we care about the allocation domain.
i.e. tv1->relu->tv2->squeeze->tv3. Here tv3 has an allocation domain that's a permutation.
When we replace it with tv1->replayed_squeeze->tv4->replayed_relu->tv5, we need to ensure that tv5 has the same allocation domain as tv3; otherwise we are going to change the semantics and return an output with the wrong memory format.
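A hedged sketch of what preserving the allocation domain could look like (mapIdToNewOut is a hypothetical helper that finds tv5's IterDomain corresponding to each of tv3's):

std::vector<IterDomain*> new_alloc;
new_alloc.reserve(tv3->getAllocationDomain().size());
for (IterDomain* id : tv3->getAllocationDomain()) {
  new_alloc.push_back(mapIdToNewOut(id));  // hypothetical mapping helper
}
// Mirror tv3's permuted allocation domain onto the replayed output tv5.
tv5->setAllocationDomain(new_alloc, /*new_contiguity=*/true);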
I'm not saying we should ignore the allocation domain. I just don't see why having an allocation domain can interfere with this translation. Why not just keep using tv3? Or, it should also be possible to reproduce the same allocation domain with tv5.
I mistook what you meant in your earlier question!
> Why not just keep using tv3?
By keeping tv3, do you mean that I can have it replayed as tv1->replayed_squeeze->tv4->replayed_relu->tv3? I didn't realize that I could just re-use tv3 here without needing to create a clone of it. Let me try that...
> Or, it should also be possible to reproduce the same allocation domain with tv5.
Yes. I was just trying to keep it simple. If we want to support general transformations, I think I can just do the same replay I did in #3644: https://github.com/NVIDIA/Fuser/pull/3644/files#diff-abe2e10add90523ff6b18e1dc50da46762420e1011078ba47ab52140dc213b6fR80-R85
FusionExecutorCache executor_cache(std::move(fusion_ptr));
auto outputs = executor_cache.runFusionWithInputs(inputs);
// validate output permutation is preserved
ASSERT_TRUE(outputs[0].is_contiguous(at::MemoryFormat::ChannelsLast));
Without the allocation domain update, this check would fail, and that would be the optimization pass changing the user's intended behavior.
It's certainly more generalized, but do we know if there's any actual case where this and #3466 would help besides the straight-line pattern of broadcast, cast, squeeze, and cast? I'm just feeling it seems a little over-engineered for a simple task like removing that particular pattern, if there isn't any other impact.
If I'm hearing this correctly, the concern is the impact of the aggressive reorder? That's a hard argument for me to win, but let me give it a shot. In the backward graph, we could encounter this squeeze + broadcast pattern pretty often, and they might not always naturally cancel each other out. See the grad rule for broadcast_in_dim in Thunder. In the original issue #3635, we have the pattern below:
I think the real troublesome pattern here is … But this might not be enough; the …
In that example, we do not have another broadcast before T38, but if that were the case, we would want to be able to re-order the …
I'm just commenting from the principle of KISS. I'd just create a new pass that would detect the four-op pattern and remove them. That'd be it.
IIUC, it doesn't seem to matter if there's both a real reduction and a squeeze. It seems what you're suggesting is that the capability of moving squeeze ops would be helpful even without a preceding broadcast op. Assuming my understanding is correct, I wouldn't disagree with the idea, but I am also not clear why we shouldn't just leave the squeeze op there. Does the reduction scheduler have any issue with it? If so, should we focus on fixing it rather than avoiding it? If there's no particular issue, why would the benefit of the reordering outweigh the optimization pass getting even more complicated?
Fixes #3635
The existing RemoveBcastSqueeze optimization pass only handles consecutive broadcast/squeeze ops. This PR expands the pass to handle cases where simple unary operations separate the broadcast/squeeze ops.
e.g. T1 = broadcast(T0); T2 = relu(T1); T3 = squeeze(T2)
In this PR, we update the pass so that, when we see a pattern where a replaceable expr is followed by a unary op, we swap the two operations, effectively pushing replaceable exprs towards the inputs, hoping they will encounter other replaceable exprs that we can then merge together. In the example above, we'll replace T2 = relu(T1); T3 = squeeze(T2) with T2 = squeeze(T1); T3 = relu(T2). In the next iteration, we'll be able to merge the broadcast and squeeze ops together, since they are now consecutive operations.
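Putting it together, a minimal sketch of the targeted pattern in nvFuser's test-style C++ API (op signatures and helpers like makeSymbolicTensor are assumed for illustration):

Fusion fusion;
FusionGuard fg(&fusion);

TensorView* tv0 = makeSymbolicTensor(2);  // T0
fusion.addInput(tv0);

TensorView* tv1 = broadcast(tv0, {false, true, false});   // T1 = broadcast(T0)
TensorView* tv2 = relu(tv1);                              // T2 = relu(T1)
TensorView* tv3 = squeeze(tv2, std::vector<int64_t>{1});  // T3 = squeeze(T2)
fusion.addOutput(tv3);

// After this PR, the pass first swaps relu and squeeze, then cancels the
// now-consecutive broadcast/squeeze pair, leaving roughly: T3 = relu(T0).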