Skip to content

Commit

Permalink
Unify unshardedSizes.
Browse files Browse the repository at this point in the history
  • Loading branch information
wujingyue committed Nov 25, 2024
1 parent cfad59b commit 98ce7eb
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 108 deletions.
54 changes: 5 additions & 49 deletions csrc/expr_evaluator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -128,44 +128,6 @@ void validateValWithConcreteValue(
}
}

// Given a TensorView and the per-device (sharded) `sizes`, returns the
// unsharded (global) sizes: each logical dimension that is DID-sharded in the
// allocation domain is multiplied by the device-mesh extent along that
// parallel type.
std::vector<int64_t> unshardedSizes(
    const TensorView* tv,
    c10::IntArrayRef sizes) {
  std::vector<int64_t> unsharded_sizes = sizes.vec();

  for (IterDomain* alloc_id : tv->getMaybeAllocationDomain()) {
    const ParallelType parallel_type = alloc_id->getParallelType();
    // Only device-parallel allocation IterDomains (DIDx, ...) shard data
    // across devices; everything else is fully allocated on each device.
    if (!isParallelTypeDeviceDim(parallel_type)) {
      continue;
    }

    // Trace the DID-parallel allocation IterDomain back to the logical
    // IterDomain(s) it was derived from.
    const auto inputs = IterVisitor::getInputsTo(
        {alloc_id},
        {tv->getLogicalDomain().begin(), tv->getLogicalDomain().end()});
    if (inputs.empty()) {
      // FIXME: is this even possible? Logical ought to dominate allocation.
      continue;
    }
    NVF_ERROR(inputs.size() == 1);

    // Locate that logical IterDomain's position in the logical domain.
    const auto iter = std::find(
        tv->getLogicalDomain().begin(),
        tv->getLogicalDomain().end(),
        inputs[0]);
    if (iter == tv->getLogicalDomain().end()) {
      // FIXME: is this even possible? Logical ought to dominate allocation.
      continue;
    }
    // Reduction IterDomains are not materialized as dimensions of the
    // at::Tensor, so index `sizes` by the count of preceding non-reduction
    // IterDomains.
    const auto index = std::count_if(
        tv->getLogicalDomain().begin(), iter, [](IterDomain* id) -> bool {
          return !id->isReduction();
        });
    unsharded_sizes.at(index) *= tv->getDeviceMesh().size(parallel_type);
  }

  return unsharded_sizes;
}

} // namespace

void ExpressionEvaluator::bindTensorDomain(
Expand All @@ -183,37 +145,31 @@ void ExpressionEvaluator::bindTensorDomain(
", but got a tensor of rank ",
t.dim());

std::vector<int64_t> sizes;
if (isSharded(tv)) {
sizes = unshardedSizes(tv, t.sizes());
} else {
sizes = t.sizes().vec();
}

std::vector<int64_t> logical_sizes = unshardedSizes(tv, t.sizes());
for (auto i : c10::irange(t.dim())) {
auto id = logical_domain[i];
if (id->isBroadcast()) {
bind_(id->extent(), 1, evaluate_validate);
if (id->hasExpandedExtent()) {
// Verify that t is also expanded
NVF_ERROR(
sizes[i] == 1 || t.stride(i) == 0,
logical_sizes[i] == 1 || t.stride(i) == 0,
"IterDomain ",
id->toString(),
" in ",
getInputPosString(tv),
"TensorView ",
tv->toString(),
" has expanded extent but input tensor has size ",
sizes[i],
logical_sizes[i],
" and stride ",
t.stride(i),
" in dimension ",
i);
bind_(id->expandedExtent(), sizes[i], evaluate_validate);
bind_(id->expandedExtent(), logical_sizes[i], evaluate_validate);
}
} else {
bind_(id->extent(), sizes[i], evaluate_validate);
bind_(id->extent(), logical_sizes[i], evaluate_validate);
}
}
}
Expand Down
45 changes: 45 additions & 0 deletions csrc/multidevice/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,51 @@ bool isSharded(const TensorView* tv) {
return is_sharded;
}

std::vector<int64_t> unshardedSizes(
const TensorView* tv,
c10::IntArrayRef sizes) {
std::vector<int64_t> unsharded_sizes = sizes.vec();

for (IterDomain* alloc_id : tv->getMaybeAllocationDomain()) {
const ParallelType parallel_type = alloc_id->getParallelType();
if (!isParallelTypeDeviceDim(parallel_type)) {
continue;
}

const auto inputs = IterVisitor::getInputsTo(
{alloc_id},
{tv->getLogicalDomain().begin(), tv->getLogicalDomain().end()});
NVF_ERROR(
!inputs.empty(),
"IterVisitor::getInputsTo shouldn't return empty unless `of` is empty.");
NVF_ERROR(
inputs.size() == 1,
"Failed to find the single logical input to ",
alloc_id,
". This is likely because there's a Merge expression from logical to allocation, which isn't supported. Inputs are: ",
toDelimitedString(inputs));

const auto iter = std::find(
tv->getLogicalDomain().begin(),
tv->getLogicalDomain().end(),
inputs[0]);
NVF_ERROR(
iter != tv->getLogicalDomain().end(),
"The found input IterDomain isn't logical. This is likely because logical doesn't dominate allocation: ",
inputs[0]);

// Count the number of non-reduction IterDomains before `iter`. Reduction
// IterDomains are not materialized in the at::Tensor's shape.
const auto index = std::count_if(
tv->getLogicalDomain().begin(), iter, [](IterDomain* id) -> bool {
return !id->isReduction();
});
unsharded_sizes.at(index) *= tv->getDeviceMesh().size(parallel_type);
}

return unsharded_sizes;
}

int64_t numDeviceDims(const TensorView* tv) {
return std::count_if(
tv->getLoopDomain().begin(),
Expand Down
29 changes: 29 additions & 0 deletions csrc/multidevice/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
// clang-format on
#pragma once

#include <c10/util/ArrayRef.h>

#include <compute_at_map.h>
#include <fusion.h>
#include <id_model/id_model.h>
Expand Down Expand Up @@ -127,4 +129,31 @@ int64_t getShardedAxis(TensorView*);

// Reorders a TensorView so that the DID parallelized axis are in front.
void reorderDIDToFront(TensorView*);

// Given a TensorView and the shape of a sharded tensor of which certain
// dimensions are partially allocated, returns the global shape that'll be used
// to bind to the TensorView's logical domain. This is to solve #3282 so we can
// bind a sharded tensor to a TensorView that has a DID-parallel loop domain.
//
// For example, when `tv` is
// logical: iM, iN
// allocation: iDIDx{D}, iN/D, iM
// and `sizes` is [2, 3], the returned shape will be [2, 3D]. This is because,
// according to the allocation domain, iM is fully allocated and iN is sharded
// and thus partially allocated.
//
// As a degenerate case, it's fine to call this function with a non-sharded
// TensorView and tensor.
//
// Limitations:
// - The function assumes that there are no Merges from logical to the
// DID-parallel IterDomains in allocation. Otherwise, it's unclear which logical
// dimension this DID-parallelization should be attributed to.
// - The function assumes that all Splits from logical to the DID-parallel
// IterDomains in allocation are even. This is because there are currently no
// ways to pass in the global shape without an API overhaul.
std::vector<int64_t> unshardedSizes(
const TensorView* tv,
c10::IntArrayRef sizes);

} // namespace nvfuser
68 changes: 9 additions & 59 deletions csrc/tensor_metadata.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -272,46 +272,6 @@ void validateAllocationSizesAndStrides(
}
}

// FIXME: strides are never changed
// Computes the unsharded (global) sizes of `tv` from the per-device `sizes`.
// Despite the name, `strides` is returned as-is (see the FIXME above); only
// the sizes of DID-sharded logical dimensions are scaled by the mesh extent.
std::pair<std::vector<int64_t>, std::vector<int64_t>> unshardedSizesAndStrides(
    TensorView* tv,
    c10::IntArrayRef sizes,
    c10::IntArrayRef strides) {
  std::vector<int64_t> unsharded_sizes = sizes.vec();
  std::vector<int64_t> unsharded_strides = strides.vec();

  for (IterDomain* alloc_id : tv->getMaybeAllocationDomain()) {
    const ParallelType parallel_type = alloc_id->getParallelType();
    // Only device-parallel allocation IterDomains indicate sharding.
    if (!isParallelTypeDeviceDim(parallel_type)) {
      continue;
    }

    // Trace the DID-parallel allocation IterDomain back to its logical
    // input(s).
    const auto inputs = IterVisitor::getInputsTo(
        {alloc_id},
        {tv->getLogicalDomain().begin(), tv->getLogicalDomain().end()});
    if (inputs.empty()) {
      // FIXME: is this even possible? Logical ought to dominate allocation.
      continue;
    }
    NVF_ERROR(inputs.size() == 1);

    // Locate the logical input's position within the logical domain.
    const auto iter = std::find(
        tv->getLogicalDomain().begin(),
        tv->getLogicalDomain().end(),
        inputs[0]);
    if (iter == tv->getLogicalDomain().end()) {
      // FIXME: is this even possible? Logical ought to dominate allocation.
      continue;
    }
    // Reduction IterDomains are not materialized in the at::Tensor's shape,
    // so index by the number of preceding non-reduction IterDomains.
    const auto index = std::count_if(
        tv->getLogicalDomain().begin(), iter, [](IterDomain* id) -> bool {
          return !id->isReduction();
        });
    unsharded_sizes.at(index) *= tv->getDeviceMesh().size(parallel_type);
  }

  return {unsharded_sizes, unsharded_strides};
}
} // namespace

std::pair<std::vector<int64_t>, std::vector<int64_t>>
Expand All @@ -322,21 +282,12 @@ inferAndValidateAllocationSizesAndStrides(
const auto& logical = tv->getLogicalDomain();
const auto& alloc = tv->getMaybeAllocationDomain();

std::vector<int64_t> logical_sizes;
std::vector<int64_t> logical_strides;
if (isSharded(tv)) {
std::tie(logical_sizes, logical_strides) =
unshardedSizesAndStrides(tv, tensor.sizes(), tensor.strides());
} else {
logical_sizes = tensor.sizes().vec();
logical_strides = tensor.strides().vec();
}

// active IDs and their shape and stride
std::vector<int64_t> logical_sizes = unshardedSizes(tv, tensor.sizes());
std::unordered_map<IterDomain*, std::pair<int64_t, int64_t>> active_ids;
int64_t dim_index = 0;
for (IterDomain* id : TensorDomain::noReductions(logical)) {
active_ids[id] = {logical_sizes[dim_index], logical_strides[dim_index]};
active_ids[id] = {logical_sizes[dim_index], tensor.stride(dim_index)};
dim_index++;
}
NVF_ERROR(dim_index == tensor.dim());
Expand All @@ -348,6 +299,8 @@ inferAndValidateAllocationSizesAndStrides(
// need to put them to the correct order.
std::vector<int64_t> allocation_sizes;
std::vector<int64_t> allocation_strides;
allocation_sizes.reserve(alloc.size());
allocation_strides.reserve(alloc.size());
for (IterDomain* id : TensorDomain::noReductions(alloc)) {
if (id->isDeviceDim()) {
allocation_sizes.push_back(1);
Expand Down Expand Up @@ -388,22 +341,19 @@ std::vector<PolymorphicValue> GetMetaData::evaluate(
metadata->data = input.data_ptr();

if (isSharded(tv)) {
auto [unsharded_sizes, unsharded_strides] =
unshardedSizesAndStrides(tv, input.sizes(), input.strides());
std::vector<int64_t> unsharded_sizes = unshardedSizes(tv, input.sizes());
metadata->logical_size_data = std::move(unsharded_sizes);
metadata->logical_size = c10::makeArrayRef(metadata->logical_size_data);
metadata->logical_stride_data = std::move(unsharded_strides);
metadata->logical_stride = c10::makeArrayRef(metadata->logical_stride_data);
} else {
metadata->logical_size = input.sizes();
metadata->logical_stride = input.strides();
}
metadata->logical_stride = input.strides();

auto [sizes, strides] =
auto [allocation_sizes, allocation_strides] =
inferAndValidateAllocationSizesAndStrides(input, tv, ee);
metadata->alloc_size_data = std::move(sizes);
metadata->alloc_size_data = std::move(allocation_sizes);
metadata->alloc_size = c10::makeArrayRef(metadata->alloc_size_data);
metadata->alloc_stride_data = std::move(strides);
metadata->alloc_stride_data = std::move(allocation_strides);
metadata->alloc_stride = c10::makeArrayRef(metadata->alloc_stride_data);
return {PolymorphicValue(std::move(struct_))};
}
Expand Down

0 comments on commit 98ce7eb

Please sign in to comment.