Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Unshard tensor sizes before binding. #3444

Merged
merged 18 commits into from
Nov 30, 2024
Merged
50 changes: 11 additions & 39 deletions csrc/expr_evaluator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
*/
// clang-format on

#include <functional>
#include <iostream>

#include <debug.h>
#include <evaluator_common.h>
#include <expr_evaluator.h>
Expand All @@ -14,11 +17,9 @@
#include <ir/iostream.h>
#include <ir/utils.h>
#include <logical_domain_map.h>
#include <multidevice/utils.h>
#include <polymorphic_value.h>

#include <functional>
#include <iostream>

namespace nvfuser {

namespace {
Expand Down Expand Up @@ -143,61 +144,32 @@ void ExpressionEvaluator::bindTensorDomain(
logical_domain.size(),
", but got a tensor of rank ",
t.dim());

std::vector<int64_t> logical_sizes = unshardedSizes(tv, t.sizes());
for (auto i : c10::irange(t.dim())) {
auto id = logical_domain[i];
if (id->isBroadcast()) {
// DIDs are ignored for broadcast.
bind_(logical_domain[i]->extent(), 1, evaluate_validate);
bind_(id->extent(), 1, evaluate_validate);
if (id->hasExpandedExtent()) {
// Verify that t is also expanded
NVF_ERROR(
t.size(i) == 1 || t.stride(i) == 0,
logical_sizes[i] == 1 || t.stride(i) == 0,
"IterDomain ",
id->toString(),
" in ",
getInputPosString(tv),
"TensorView ",
tv->toString(),
" has expanded extent but input tensor has size ",
t.size(i),
logical_sizes[i],
" and stride ",
t.stride(i),
" in dimension ",
i);
bind_(
logical_domain[i]->expandedExtent(), t.size(i), evaluate_validate);
bind_(id->expandedExtent(), logical_sizes[i], evaluate_validate);
}
} else {
if (logical_domain[i]->isDeviceDim()) {
// Currently we have the restrictions:
// (1) Devices parallelized axis extent == DeviceMesh's extent
// (2) Device parallelized axis cannot be split or merged
// Therefore, the device parallelized extents will always be allocated
// with size 1, but the symbolic axis extent is binded with the extent
// of the DeviceMesh
NVF_CHECK(
1 == t.size(i),
"TensorView ",
tv->toString(),
getInputPosString(tv),
" IterDomain ",
id->toString(),
"is sharded and must have size 1, but input tensor has size ",
t.size(i));
NVF_CHECK(
tv->hasDeviceMesh(),
"TV ",
tv->toString(),
getInputPosString(tv),
" has an empty DeviceMesh with DID parallelization")
bind_(
logical_domain[i]->extent(),
static_cast<int64_t>(
tv->getDeviceMesh().size(logical_domain[i]->getParallelType())),
evaluate_validate);
} else {
bind_(logical_domain[i]->extent(), t.size(i), evaluate_validate);
}
bind_(id->extent(), logical_sizes[i], evaluate_validate);
}
}
}
Expand Down
7 changes: 6 additions & 1 deletion csrc/fusion_segmenter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1883,7 +1883,12 @@ void eraseInputDistinctRootDomains(Fusion* fusion) {
std::vector<IterDomain*> new_alloc;
new_alloc.reserve(tv->getAllocationDomain().size());
for (IterDomain* alloc_id : tv->getAllocationDomain()) {
new_alloc.push_back(replay.getReplay().at(alloc_id));
IterDomain* new_alloc_id = replay.getReplay().at(alloc_id);
// ReplayTransformations replays transforms but not parallelization, so
// we have to manually parallelize the new allocation ID. In other
// places, parallelization is usually done through parallelizeAllLike.
new_alloc_id->parallelize(alloc_id->getParallelType());
new_alloc.push_back(new_alloc_id);
}

std::vector<IterDomain*> new_loop;
Expand Down
23 changes: 10 additions & 13 deletions csrc/ir/nodes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3238,24 +3238,21 @@ bool TensorDomain::sameAs(
std::string TensorDomain::toString(const int indent_size, const bool loop_only)
const {
std::stringstream ss;
if (nDims() == 0) {
indent(ss, indent_size) << "[ ]";
return ss.str();
}
indent(ss, indent_size) << "[ " << toDelimitedString(loop()) << " ]";
if (!loop_only) {
if (loop_only) {
indent(ss, indent_size) << "[" << toDelimitedString(loop()) << "]";
} else {
indent(ss, indent_size)
<< "logical=[" << toDelimitedString(logical()) << "]" << std::endl;
if (hasRoot()) {
ss << "," << std::endl;
indent(ss, indent_size + 1)
<< "root=[ " << toDelimitedString(root()) << " ]";
<< "root=[" << toDelimitedString(root()) << "]" << std::endl;
}
ss << "," << std::endl;
indent(ss, indent_size + 1)
<< "logical=[ " << toDelimitedString(logical()) << " ]";
if (!allocation_domain_.empty()) {
ss << "," << std::endl;
<< "loop=[" << toDelimitedString(loop()) << "]" << std::endl;
if (hasAllocation()) {
indent(ss, indent_size + 1)
<< "allocation=[ " << toDelimitedString(allocation()) << " ]";
<< "allocation=[" << toDelimitedString(allocation()) << "]"
<< std::endl;
}
}
return ss.str();
Expand Down
49 changes: 47 additions & 2 deletions csrc/multidevice/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,8 @@ std::pair<std::vector<IterDomain*>, std::vector<IterDomain*>> getShardingChanges

bool isSharded(const TensorView* tv) {
bool is_sharded = false;
for (IterDomain* id : tv->getLoopDomain()) {
if (!id->isDeviceDim()) {
for (IterDomain* alloc_id : tv->getMaybeAllocationDomain()) {
if (!alloc_id->isDeviceDim()) {
continue;
}

Expand All @@ -121,6 +121,51 @@ bool isSharded(const TensorView* tv) {
return is_sharded;
}

std::vector<int64_t> unshardedSizes(
const TensorView* tv,
c10::IntArrayRef sizes) {
std::vector<int64_t> unsharded_sizes = sizes.vec();

for (IterDomain* alloc_id : tv->getMaybeAllocationDomain()) {
const ParallelType parallel_type = alloc_id->getParallelType();
if (!isParallelTypeDeviceDim(parallel_type)) {
continue;
}

const auto inputs = IterVisitor::getInputsTo(
{alloc_id},
{tv->getLogicalDomain().begin(), tv->getLogicalDomain().end()});
NVF_ERROR(
!inputs.empty(),
"IterVisitor::getInputsTo shouldn't return empty unless `of` is empty.");
NVF_ERROR(
inputs.size() == 1,
"Failed to find the single logical input to ",
alloc_id,
". This is likely because there's a Merge expression from logical to allocation, which isn't supported. Inputs are: ",
toDelimitedString(inputs));

const auto iter = std::find(
tv->getLogicalDomain().begin(),
tv->getLogicalDomain().end(),
inputs[0]);
NVF_ERROR(
iter != tv->getLogicalDomain().end(),
"The found input IterDomain isn't logical. This is likely because logical doesn't dominate allocation: ",
inputs[0]);

// Count the number of non-reduction IterDomains before `iter`. Reduction
// IterDomains are not materialized in the at::Tensor's shape.
const auto index = std::count_if(
tv->getLogicalDomain().begin(), iter, [](IterDomain* id) -> bool {
return !id->isReduction();
});
unsharded_sizes.at(index) *= tv->getDeviceMesh().size(parallel_type);
}

return unsharded_sizes;
}

int64_t numDeviceDims(const TensorView* tv) {
return std::count_if(
tv->getLoopDomain().begin(),
Expand Down
41 changes: 41 additions & 0 deletions csrc/multidevice/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
// clang-format on
#pragma once

#include <c10/util/ArrayRef.h>

#include <compute_at_map.h>
#include <fusion.h>
#include <id_model/id_model.h>
Expand Down Expand Up @@ -127,4 +129,43 @@ int64_t getShardedAxis(TensorView*);

// Reorders a TensorView so that the DID parallelized axis are in front.
void reorderDIDToFront(TensorView*);

// Given a TensorView and the shape of a sharded tensor of which certain
// dimensions are partially allocated, returns the global shape that'll be used
// to bind to the TensorView's logical domain. This is to solve #3282 so we can
// bind a sharded tensor to a TensorView that has a DID-parallel loop domain.
//
// For example, when `tv` is
// logical: iM, iN
// allocation: iDIDx{D}, iN/D, iM
// and `sizes` is [2, 3], the returned shape will be [2, 3D]. This is because,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think there is a mistake here..
if we bind {2,3} to {N/D, M}
then M=3 and N=2D, and so according to the comment, it should return the shape corresponding to the logical domain, i.e., [3, 2D]. Am I missing something?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

moreover, do we support transposition?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The comment is correct and is consistent with the code.

ExpressionEvaluator::bindTensorDomain basically does the following

unsharded_sizes = unshardedSizes(t.sizes());
for (i : range(t.dim())) {
  bind(logical_domain[i], unsharded_sizes[i]);
}

That's also why I prefer to say we bind the unsharded sizes to the logical domain instead of allocation.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The comment is correct and is consistent with the code.

I see. Imo it is error prone to silently discard transposition. We should assert that only splits have been applied, OR, we should support transposition, which shouldn't be too hard...

ExpressionEvaluator::bindTensorDomain basically does the following

unsharded_sizes = unshardedSizes(t.sizes());
for (i : range(t.dim())) {
  bind(logical_domain[i], unsharded_sizes[i]);
}

That's also why I prefer to say we bind the unsharded sizes to the logical domain instead of allocation.

I would say in this case that we bind to neither the logical nor the allocation, but to some hybrid domain where starting from the logical we only applied the splits. This is a bit counter-intuitive to me.

In your snippet above, everything is contained in the unsharded_sizes which basically embeds a mapping from allocation (or more precisely the hybrid domain I described earlier) to logical.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it is error prone to silently discard transposition.

I believe code as is supports transposition. (I assume by transposition you mean TensorView::reorder). To assure you that, I added a test in the latest commit.

// according to the allocation domain, iM is fully allocated and iN is sharded
// and thus partially allocated.
//
// If the TensorView is not sharded, this function returns `sizes`.
//
// Limitations:
// - The function assumes that there are no Merges from logical to the
// DID-parallel IterDomains in allocation. Otherwise, it's unclear which logical
wujingyue marked this conversation as resolved.
Show resolved Hide resolved
// dimension this DID-parallelization should be attributed to.
// - The function assumes that all Splits from logical to the DID-parallel
// IterDomains in allocation are even. This is because there are currently no
wujingyue marked this conversation as resolved.
Show resolved Hide resolved
// ways to pass in the global shape.
//
// Despite these limitations, I took this approach as a shortcut to fix #3282,
// which blocked many other tasks. I'm however open to other better, long-term
// solutions. Some alternatives considered in #3282 are:
// - Try to bind `at::Tensor`s to allocation domains instead of logical. Many
// `*Op::evaluate` methods (e.g.
// https://github.com/NVIDIA/Fuser/blob/2415d904d1e9a5da7ca6fb1a55d3045bbd510341/csrc/ir/nodes.cpp#L4321-L4329)
// assume the input/output `at::Tensor`s have the same dimension order as the
// logical domain. Doing so would have to change them all.
// - Try to pass into FusionExecutorCache both logical (global) shapes and
// allocated (local) tensors for sharded TensorViews. The logical shapes would
// have to be passed through FusionKernelRuntime, FusionExecutor,
// ExpressionEvaluator, and so on, which is an API overhaul.
std::vector<int64_t> unshardedSizes(
const TensorView* tv,
c10::IntArrayRef sizes);

} // namespace nvfuser
36 changes: 19 additions & 17 deletions csrc/scheduler/transpose.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -227,34 +227,36 @@ class DomainMap : public pointwise_utils::DomainMap {
root_dim,
" in tensor ",
tv);
auto replay_exprs = StmtSort::getExprsBetween(
std::vector<Expr*> replay_exprs = StmtSort::getExprsBetween(
{mapped_id}, {tv->getLoopDomain().begin(), tv->getLoopDomain().end()});
// Project the root id to loop id. Similar to projectIdToRFactor.
for (auto expr : replay_exprs) {
if (expr->isA<Split>()) {
// Split with factor one is not supposed to be here, reshape would map
// this to a broadcast. This is a conservative assert, we can relaxed it
// and support with mapping it to outer.
NVF_ERROR(
!expr->as<Split>()->factor()->isOneInt(),
"split with factor one is supposed to be translated to broadcast by reshape");
wujingyue marked this conversation as resolved.
Show resolved Hide resolved
if (expr->as<Split>()->in() == mapped_id) {
mapped_id = expr->as<Split>()->inner();
for (auto* expr : replay_exprs) {
if (auto* split = dynamic_cast<Split*>(expr)) {
if (split->in() == mapped_id) {
if (split->inner()->extent()->isOneInt() &&
!split->outer()->extent()->isOneInt()) {
mapped_id = split->outer();
} else {
mapped_id = split->inner();
}
}
} else if (expr->isA<Merge>()) {
} else if (auto* merge = dynamic_cast<Merge*>(expr)) {
// Merge with size-1 dimension is not supposed to be here, reshape would
// map this to a squeeze. This is a conservative assert; we can relax
// it and add support by mapping it to out.
NVF_ERROR(
!expr->as<Merge>()->inner()->extent()->isOneInt(),
!merge->inner()->extent()->isOneInt(),
"merge with size-1 dimension is supposed to be translated to squeeze by reshape");
if (expr->as<Merge>()->inner() == mapped_id) {
mapped_id = expr->as<Merge>()->out();
if (merge->inner() == mapped_id) {
mapped_id = merge->out();
}
} else if (auto* resize = dynamic_cast<Resize*>(expr)) {
if (resize->in() == mapped_id) {
mapped_id = resize->out();
}
} else if (expr->isA<Resize>() && expr->as<Resize>()->in() == mapped_id) {
mapped_id = expr->as<Resize>()->out();
}
}

// Find the position of the loop id
const auto& dom = tv->getLoopDomain();
for (auto i : c10::irange(dom.size())) {
Expand Down
Loading