diff --git a/csrc/expr_evaluator.cpp b/csrc/expr_evaluator.cpp
index d4ca6daa022..a2ebccfb7b3 100644
--- a/csrc/expr_evaluator.cpp
+++ b/csrc/expr_evaluator.cpp
@@ -6,6 +6,9 @@
  */
 // clang-format on
+#include
+#include
+
 #include
 #include
 #include
@@ -14,11 +17,9 @@
 #include
 #include
 #include
+#include
 #include
-#include
-#include
-
 namespace nvfuser {
 
 namespace {
@@ -143,15 +144,16 @@ void ExpressionEvaluator::bindTensorDomain(
       logical_domain.size(),
       ", but got a tensor of rank ",
       t.dim());
+
+  std::vector<int64_t> logical_sizes = unshardedSizes(tv, t.sizes());
   for (auto i : c10::irange(t.dim())) {
     auto id = logical_domain[i];
     if (id->isBroadcast()) {
-      // DIDs are ignored for broadcast.
-      bind_(logical_domain[i]->extent(), 1, evaluate_validate);
+      bind_(id->extent(), 1, evaluate_validate);
       if (id->hasExpandedExtent()) {
         // Verify that t is also expanded
         NVF_ERROR(
-            t.size(i) == 1 || t.stride(i) == 0,
+            logical_sizes[i] == 1 || t.stride(i) == 0,
             "IterDomain ",
             id->toString(),
             " in ",
@@ -159,45 +161,15 @@ void ExpressionEvaluator::bindTensorDomain(
             "TensorView ",
             tv->toString(),
             " has expanded extent but input tensor has size ",
-            t.size(i),
+            logical_sizes[i],
             " and stride ",
             t.stride(i),
             " in dimension ",
             i);
-        bind_(
-            logical_domain[i]->expandedExtent(), t.size(i), evaluate_validate);
+        bind_(id->expandedExtent(), logical_sizes[i], evaluate_validate);
       }
     } else {
-      if (logical_domain[i]->isDeviceDim()) {
-        // Currently we have the restrictions:
-        // (1) Devices parallelized axis extent == DeviceMesh's extent
-        // (2) Device parallelized axis cannot be split or merged
-        // Therefore, the device parallelized extents will always be allocated
-        // with size 1, but the symbolic axis extent is binded with the extent
-        // of the DeviceMesh
-        NVF_CHECK(
-            1 == t.size(i),
-            "TensorView ",
-            tv->toString(),
-            getInputPosString(tv),
-            " IterDomain ",
-            id->toString(),
-            "is sharded and must have size 1, but input tensor has size ",
-            t.size(i));
-        NVF_CHECK(
-            tv->hasDeviceMesh(),
-            "TV ",
-            tv->toString(),
-            getInputPosString(tv),
-            " has an empty DeviceMesh with DID parallelization")
-        bind_(
-            logical_domain[i]->extent(),
-            static_cast<int64_t>(
-                tv->getDeviceMesh().size(logical_domain[i]->getParallelType())),
-            evaluate_validate);
-      } else {
-        bind_(logical_domain[i]->extent(), t.size(i), evaluate_validate);
-      }
+      bind_(id->extent(), logical_sizes[i], evaluate_validate);
     }
   }
 }
diff --git a/csrc/fusion_segmenter.cpp b/csrc/fusion_segmenter.cpp
index 506a2e81987..c98543a179a 100644
--- a/csrc/fusion_segmenter.cpp
+++ b/csrc/fusion_segmenter.cpp
@@ -1883,7 +1883,12 @@ void eraseInputDistinctRootDomains(Fusion* fusion) {
       std::vector<IterDomain*> new_alloc;
       new_alloc.reserve(tv->getAllocationDomain().size());
       for (IterDomain* alloc_id : tv->getAllocationDomain()) {
-        new_alloc.push_back(replay.getReplay().at(alloc_id));
+        IterDomain* new_alloc_id = replay.getReplay().at(alloc_id);
+        // ReplayTransformations replays transforms but not parallelization, so
+        // we have to manually parallelize the new allocation ID. In other
+        // places, parallelization is usually done through parallelizeAllLike.
+        new_alloc_id->parallelize(alloc_id->getParallelType());
+        new_alloc.push_back(new_alloc_id);
       }
 
       std::vector<IterDomain*> new_loop;
diff --git a/csrc/ir/nodes.cpp b/csrc/ir/nodes.cpp
index c93c4980e85..6906861814e 100644
--- a/csrc/ir/nodes.cpp
+++ b/csrc/ir/nodes.cpp
@@ -3238,24 +3238,21 @@ bool TensorDomain::sameAs(
 std::string TensorDomain::toString(const int indent_size, const bool loop_only) const {
   std::stringstream ss;
-  if (nDims() == 0) {
-    indent(ss, indent_size) << "[ ]";
-    return ss.str();
-  }
-  indent(ss, indent_size) << "[ " << toDelimitedString(loop()) << " ]";
-  if (!loop_only) {
+  if (loop_only) {
+    indent(ss, indent_size) << "[" << toDelimitedString(loop()) << "]";
+  } else {
+    indent(ss, indent_size)
+        << "logical=[" << toDelimitedString(logical()) << "]" << std::endl;
     if (hasRoot()) {
-      ss << "," << std::endl;
       indent(ss, indent_size + 1)
-          << "root=[ " << toDelimitedString(root()) << " ]";
+          << "root=[" << toDelimitedString(root()) << "]" << std::endl;
     }
-    ss << "," << std::endl;
     indent(ss, indent_size + 1)
-        << "logical=[ " << toDelimitedString(logical()) << " ]";
-    if (!allocation_domain_.empty()) {
-      ss << "," << std::endl;
+        << "loop=[" << toDelimitedString(loop()) << "]" << std::endl;
+    if (hasAllocation()) {
       indent(ss, indent_size + 1)
-          << "allocation=[ " << toDelimitedString(allocation()) << " ]";
+          << "allocation=[" << toDelimitedString(allocation()) << "]"
+          << std::endl;
     }
   }
   return ss.str();
 }
diff --git a/csrc/multidevice/utils.cpp b/csrc/multidevice/utils.cpp
index c1943fed8f5..24b7e582104 100644
--- a/csrc/multidevice/utils.cpp
+++ b/csrc/multidevice/utils.cpp
@@ -106,8 +106,8 @@ std::pair<std::vector<IterDomain*>, std::vector<IterDomain*>> getShardingChanges
 
 bool isSharded(const TensorView* tv) {
   bool is_sharded = false;
-  for (IterDomain* id : tv->getLoopDomain()) {
-    if (!id->isDeviceDim()) {
+  for (IterDomain* alloc_id : tv->getMaybeAllocationDomain()) {
+    if (!alloc_id->isDeviceDim()) {
       continue;
     }
@@ -121,6 +121,51 @@ bool isSharded(const TensorView* tv) {
   return is_sharded;
 }
 
+std::vector<int64_t> unshardedSizes(
+    const TensorView* tv,
+    c10::IntArrayRef sizes) {
+  std::vector<int64_t> unsharded_sizes = sizes.vec();
+
+  for (IterDomain* alloc_id : tv->getMaybeAllocationDomain()) {
+    const ParallelType parallel_type = alloc_id->getParallelType();
+    if (!isParallelTypeDeviceDim(parallel_type)) {
+      continue;
+    }
+
+    const auto inputs = IterVisitor::getInputsTo(
+        {alloc_id},
+        {tv->getLogicalDomain().begin(), tv->getLogicalDomain().end()});
+    NVF_ERROR(
+        !inputs.empty(),
+        "IterVisitor::getInputsTo shouldn't return empty unless `of` is empty.");
+    NVF_ERROR(
+        inputs.size() == 1,
+        "Failed to find the single logical input to ",
+        alloc_id,
+        ". This is likely because there's a Merge expression from logical to allocation, which isn't supported. Inputs are: ",
+        toDelimitedString(inputs));
+
+    const auto iter = std::find(
+        tv->getLogicalDomain().begin(),
+        tv->getLogicalDomain().end(),
+        inputs[0]);
+    NVF_ERROR(
+        iter != tv->getLogicalDomain().end(),
+        "The found input IterDomain isn't logical. This is likely because logical doesn't dominate allocation: ",
+        inputs[0]);
+
+    // Count the number of non-reduction IterDomains before `iter`. Reduction
+    // IterDomains are not materialized in the at::Tensor's shape.
+    const auto index = std::count_if(
+        tv->getLogicalDomain().begin(), iter, [](IterDomain* id) -> bool {
+          return !id->isReduction();
+        });
+    unsharded_sizes.at(index) *= tv->getDeviceMesh().size(parallel_type);
+  }
+
+  return unsharded_sizes;
+}
+
 int64_t numDeviceDims(const TensorView* tv) {
   return std::count_if(
       tv->getLoopDomain().begin(),
diff --git a/csrc/multidevice/utils.h b/csrc/multidevice/utils.h
index 12013e918b4..5be2e11bd15 100644
--- a/csrc/multidevice/utils.h
+++ b/csrc/multidevice/utils.h
@@ -7,6 +7,8 @@
 // clang-format on
 #pragma once
 
+#include
+
 #include
 #include
 #include
@@ -127,4 +129,43 @@ int64_t getShardedAxis(TensorView*);
 
 // Reorders a TensorView so that the DID parallelized axis are in front.
 void reorderDIDToFront(TensorView*);
+
+// Given a TensorView and the shape of a sharded tensor of which certain
+// dimensions are partially allocated, returns the global shape that'll be used
+// to bind to the TensorView's logical domain. This is to solve #3282 so we can
+// bind a sharded tensor to a TensorView that has a DID-parallel loop domain.
+//
+// For example, when `tv` is
+//   logical: iM, iN
+//   allocation: iDIDx{D}, iN/D, iM
+// and `sizes` is [2, 3], the returned shape will be [2, 3D]. This is because,
+// according to the allocation domain, iM is fully allocated and iN is sharded
+// and thus partially allocated.
+//
+// If the TensorView is not sharded, this function returns `sizes`.
+//
+// Limitations:
+// - The function assumes that there are no Merges from logical to the
+// DID-parallel IterDomains in allocation. Otherwise, it's unclear which logical
+// dimension this DID-parallelization should be attributed to.
+// - The function assumes that all Splits from logical to the DID-parallel
+// IterDomains in allocation are even. This is because there is currently no
+// way to pass in the global shape.
+//
+// Despite these limitations, I took this approach as a shortcut to fix #3282,
+// which blocked many other tasks. I'm however open to better, long-term
+// solutions. Some alternatives considered in #3282 are:
+// - Try to bind `at::Tensor`s to allocation domains instead of logical. Many
+// `*Op::evaluate` methods (e.g.
+// https://github.com/NVIDIA/Fuser/blob/2415d904d1e9a5da7ca6fb1a55d3045bbd510341/csrc/ir/nodes.cpp#L4321-L4329)
+// assume the input/output `at::Tensor`s have the same dimension order as the
+// logical domain. Doing so would require changing them all.
+// - Try to pass into FusionExecutorCache both logical (global) shapes and
+// allocated (local) tensors for sharded TensorViews. The logical shapes would
+// have to be passed through FusionKernelRuntime, FusionExecutor,
+// ExpressionEvaluator, and so on, which is an API overhaul.
+std::vector<int64_t> unshardedSizes(
+    const TensorView* tv,
+    c10::IntArrayRef sizes);
+
 } // namespace nvfuser
diff --git a/csrc/scheduler/transpose.cpp b/csrc/scheduler/transpose.cpp
index e5023f4e25c..7e320f99a91 100644
--- a/csrc/scheduler/transpose.cpp
+++ b/csrc/scheduler/transpose.cpp
@@ -227,34 +227,36 @@ class DomainMap : public pointwise_utils::DomainMap {
         root_dim,
         " in tensor ",
         tv);
-    auto replay_exprs = StmtSort::getExprsBetween(
+    std::vector<Expr*> replay_exprs = StmtSort::getExprsBetween(
        {mapped_id}, {tv->getLoopDomain().begin(), tv->getLoopDomain().end()});
     // Project the root id to loop id. Similar to projectIdToRFactor.
-    for (auto expr : replay_exprs) {
-      if (expr->isA<Split>()) {
-        // Split with factor one is not supposed to be here, reshape would map
-        // this to a broadcast. This is a conservative assert, we can relaxed it
-        // and support with mapping it to outer.
-        NVF_ERROR(
-            !expr->as<Split>()->factor()->isOneInt(),
-            "split with factor one is supposed to be translated to broadcast by reshape");
-        if (expr->as<Split>()->in() == mapped_id) {
-          mapped_id = expr->as<Split>()->inner();
+    for (auto* expr : replay_exprs) {
+      if (auto* split = dynamic_cast<Split*>(expr)) {
+        if (split->in() == mapped_id) {
+          if (split->inner()->extent()->isOneInt() &&
+              !split->outer()->extent()->isOneInt()) {
+            mapped_id = split->outer();
+          } else {
+            mapped_id = split->inner();
+          }
         }
-      } else if (expr->isA<Merge>()) {
+      } else if (auto* merge = dynamic_cast<Merge*>(expr)) {
         // Merge with size-1 dimension is not supposed to be here, reshape would
         // map this to a squeeze. This is a conservative assert, we can relaxed
         // it and support with mapping it to out.
         NVF_ERROR(
-            !expr->as<Merge>()->inner()->extent()->isOneInt(),
+            !merge->inner()->extent()->isOneInt(),
             "merge with size-1 dimension is supposed to be translated to squeeze by reshape");
-        if (expr->as<Merge>()->inner() == mapped_id) {
-          mapped_id = expr->as<Merge>()->out();
+        if (merge->inner() == mapped_id) {
+          mapped_id = merge->out();
+        }
+      } else if (auto* resize = dynamic_cast<Resize*>(expr)) {
+        if (resize->in() == mapped_id) {
+          mapped_id = resize->out();
         }
-      } else if (expr->isA<Resize>() && expr->as<Resize>()->in() == mapped_id) {
-        mapped_id = expr->as<Resize>()->out();
       }
     }
+
     // Find the position of the loop id
     const auto& dom = tv->getLoopDomain();
     for (auto i : c10::irange(dom.size())) {
diff --git a/csrc/tensor_metadata.cpp b/csrc/tensor_metadata.cpp
index 32fdee2de42..96f1d85fd2d 100644
--- a/csrc/tensor_metadata.cpp
+++ b/csrc/tensor_metadata.cpp
@@ -37,6 +37,7 @@ class ForwardTraverseFromLogicalToAlloc {
       // TODO: see [Allocation domain on both side of logical]
       return;
     }
+
     auto [in_size, in_stride] = in_it->second;
     auto factor = ee_.evaluate(split->factor()).as<int64_t>();
     NVF_ERROR(
@@ -44,17 +45,24 @@ class ForwardTraverseFromLogicalToAlloc {
         "The logical domain and allocation domain of fusion input/output ",
         "tensors must be a one-to-one map, therefore, ",
         "non-divisible split is not allowed in allocation domain");
+
+    int64_t inner_size = 0;
+    int64_t outer_size = 0;
+    if (split->innerSplit()) {
+      outer_size = in_size / factor;
+      inner_size = factor;
+    } else {
+      outer_size = factor;
+      inner_size = in_size / factor;
+    }
+
     NVF_ERROR(active_ids_.erase(in) == 1);
+    NVF_ERROR(active_ids_.emplace(inner, std::make_pair(inner_size, in_stride))
+                  .second);
     NVF_ERROR(
         active_ids_
-            .emplace(inner, std::pair<int64_t, int64_t>{factor, in_stride})
+            .emplace(outer, std::make_pair(outer_size, in_stride * inner_size))
             .second);
-    NVF_ERROR(active_ids_
-                  .emplace(
-                      outer,
-                      std::pair<int64_t, int64_t>{
-                          in_size / factor, in_stride * factor})
-                  .second);
   }
 
   void handle(Merge* merge) {
@@ -259,6 +267,10 @@ void validateAllocationSizesAndStrides(
             "Stride mismatch with contiguity info. ",
             " allocation domain: ",
             ir_utils::toString(alloc_dom),
+            ": sizes: ",
+            sizes,
+            ": strides: ",
+            strides,
             "; contiguity: ",
             toDelimitedString(contiguity),
             "; dim: ",
@@ -283,10 +295,11 @@ inferAndValidateAllocationSizesAndStrides(
   const auto& alloc = tv->getMaybeAllocationDomain();
 
   // active IDs and their shape and stride
+  std::vector<int64_t> logical_sizes = unshardedSizes(tv, tensor.sizes());
   std::unordered_map<IterDomain*, std::pair<int64_t, int64_t>> active_ids;
   int64_t dim_index = 0;
   for (IterDomain* id : TensorDomain::noReductions(logical)) {
-    active_ids[id] = {tensor.size(dim_index), tensor.stride(dim_index)};
+    active_ids[id] = {logical_sizes.at(dim_index), tensor.stride(dim_index)};
     dim_index++;
   }
   NVF_ERROR(dim_index == tensor.dim());
@@ -296,51 +309,27 @@ inferAndValidateAllocationSizesAndStrides(
 
   // Now active_ids should contain the final sizes and strides, unordered. We
   // need to put them to the correct order.
-  std::vector<int64_t> sizes;
-  std::vector<int64_t> strides;
-  sizes.reserve(alloc.size());
-  strides.reserve(alloc.size());
+  std::vector<int64_t> allocation_sizes;
+  std::vector<int64_t> allocation_strides;
+  allocation_sizes.reserve(alloc.size());
+  allocation_strides.reserve(alloc.size());
   for (IterDomain* id : TensorDomain::noReductions(alloc)) {
     if (id->isDeviceDim()) {
-      sizes.push_back(1);
+      allocation_sizes.push_back(1);
     } else {
-      sizes.push_back(active_ids.at(id).first);
+      allocation_sizes.push_back(active_ids.at(id).first);
     }
-    strides.push_back(active_ids.at(id).second);
+    allocation_strides.push_back(active_ids.at(id).second);
   }
 
   // Only validate final sizes and strides when we have a non-empty tensor.
   if (tensor.numel() != 0) {
     validateAllocationSizesAndStrides(
-        alloc, tv->getContiguity(), sizes, strides);
+        alloc, tv->getContiguity(), allocation_sizes, allocation_strides);
   }
-  return {std::move(sizes), std::move(strides)};
+  return {std::move(allocation_sizes), std::move(allocation_strides)};
 }
 
-namespace {
-std::pair<std::vector<int64_t>, std::vector<int64_t>> unshardedSizesAndStrides(
-    TensorView* tv,
-    c10::IntArrayRef sizes,
-    c10::IntArrayRef strides) {
-  std::vector<int64_t> unsharded_sizes(sizes.size());
-  std::vector<int64_t> unsharded_strides(strides.size());
-  for (const auto i : c10::irange(sizes.size())) {
-    IterDomain* id = tv->getLogicalDomain()[i];
-    if (id->isDeviceDim()) {
-      unsharded_sizes[i] = tv->getDeviceMesh().size(id->getParallelType());
-      // This probably doesn't matter in practice unless a kernel accidentally
-      // tries to access the data on another rank. To be safe, set the stride
-      // to zero, analogous to an expanded broadcast dimension.
-      unsharded_strides[i] = 0;
-    } else {
-      unsharded_sizes[i] = sizes[i];
-      unsharded_strides[i] = strides[i];
-    }
-  }
-  return {unsharded_sizes, unsharded_strides};
-}
-} // namespace
-
 std::vector<PolymorphicValue> GetMetaData::evaluate(
     const ExpressionEvaluator& ee,
     const std::vector<PolymorphicValue>& inputs) const {
@@ -364,22 +353,19 @@ std::vector<PolymorphicValue> GetMetaData::evaluate(
   metadata->data = input.data_ptr();
 
   if (isSharded(tv)) {
-    auto [unsharded_sizes, unsharded_strides] =
-        unshardedSizesAndStrides(tv, input.sizes(), input.strides());
+    std::vector<int64_t> unsharded_sizes = unshardedSizes(tv, input.sizes());
     metadata->logical_size_data = std::move(unsharded_sizes);
     metadata->logical_size = c10::makeArrayRef(metadata->logical_size_data);
-    metadata->logical_stride_data = std::move(unsharded_strides);
-    metadata->logical_stride = c10::makeArrayRef(metadata->logical_stride_data);
   } else {
     metadata->logical_size = input.sizes();
-    metadata->logical_stride = input.strides();
   }
+  metadata->logical_stride = input.strides();
 
-  auto [sizes, strides] =
+  auto [allocation_sizes, allocation_strides] =
       inferAndValidateAllocationSizesAndStrides(input, tv, ee);
-  metadata->alloc_size_data = std::move(sizes);
+  metadata->alloc_size_data = std::move(allocation_sizes);
   metadata->alloc_size = c10::makeArrayRef(metadata->alloc_size_data);
-  metadata->alloc_stride_data = std::move(strides);
+  metadata->alloc_stride_data = std::move(allocation_strides);
   metadata->alloc_stride = c10::makeArrayRef(metadata->alloc_stride_data);
   return {PolymorphicValue(std::move(struct_))};
 }
diff --git a/csrc/transform_replay.cpp b/csrc/transform_replay.cpp
index 06e15929aa9..6c58d83528a 100644
--- a/csrc/transform_replay.cpp
+++ b/csrc/transform_replay.cpp
@@ -770,11 +770,13 @@ std::pair<TensorDomain*, int64_t> TransformReplay::replayCasP(
     new_contiguity.reserve(producer_rank);
 
     for (auto i : c10::irange(producer_rank)) {
-      IterDomain* id = producer->getAllocationDomain()[i];
+      IterDomain* alloc_id = producer->getAllocationDomain()[i];
       // We won't find reduction IterDomains in the map. See
       // AllocationDomainTest.CacheBefore.
-      if (auto it = p2c_map.find(id); it != p2c_map.end()) {
-        new_allocation_domain.push_back(it->second);
+      if (auto it = p2c_map.find(alloc_id); it != p2c_map.end()) {
+        IterDomain* new_alloc_id = it->second;
+        new_alloc_id->parallelize(alloc_id->getParallelType());
+        new_allocation_domain.push_back(new_alloc_id);
         new_contiguity.push_back(producer->getContiguity()[i]);
       }
     }
diff --git a/tests/cpp/test_multidevice_sharding.cpp b/tests/cpp/test_multidevice_sharding.cpp
index 1e1ff2eab9e..3adac90bc5e 100644
--- a/tests/cpp/test_multidevice_sharding.cpp
+++ b/tests/cpp/test_multidevice_sharding.cpp
@@ -340,6 +340,76 @@ TEST_F(MultiDeviceTest, Transpose) {
       UnorderedElementsAre(HeuristicIs(SchedulerType::Transpose)));
 }
 
+TEST_F(MultiDeviceTest, LoopSplit) {
+  auto fusion = std::make_unique<Fusion>();
+  FusionGuard fg(fusion.get());
+
+  const auto num_devices = communicator_->size();
+  auto mesh = DeviceMesh::createForNumDevices(num_devices);
+
+  TensorView* in = makeContigConcreteTensor({num_devices * 3});
+  in->setDeviceMesh(mesh);
+  fusion->addInput(in);
+  TensorView* out = set(in);
+  fusion->addOutput(out);
+
+  for (auto* tv : {in, out}) {
+    tv->split(0, num_devices, /*inner_split=*/false);
+    tv->axis(0)->parallelize(ParallelType::DIDx);
+    tv->setAllocationDomain(tv->getLoopDomain(), true);
+  }
+
+  at::Tensor in_tensor = at::randn({3}, tensor_options);
+  FusionExecutorCache executor_cache(std::move(fusion));
+  at::Tensor out_tensor = executor_cache.runFusionWithInputs({in_tensor})[0];
+
+  testValidate(
+      executor_cache.fusion(),
+      {out_tensor},
+      {in_tensor},
+      {in_tensor},
+      __LINE__,
+      __FILE__);
+}
+
+TEST_F(MultiDeviceTest, LoopSplitWithReorder) {
+  auto fusion = std::make_unique<Fusion>();
+  FusionGuard fg(fusion.get());
+
+  const auto num_devices = communicator_->size();
+  auto mesh = DeviceMesh::createForNumDevices(num_devices);
+
+  TensorView* in = makeContigConcreteTensor({2, num_devices * 3});
+  in->setDeviceMesh(mesh);
+  fusion->addInput(in);
+
+  TensorView* out = set(in);
+  fusion->addOutput(out);
+
+  // logical: i{2}, i{3D}
+  // allocation: iDIDx{D}, i{3}, i{2}
+  in->split(1, num_devices, /*inner_split=*/false);
+  in->reorder({{0, -1}});
+  in->axis(0)->parallelize(ParallelType::DIDx);
+  in->setAllocationDomain(in->getLoopDomain(), true);
+
+  out->split(1, num_devices, /*inner_split=*/false);
+  out->axis(1)->parallelize(ParallelType::DIDx);
+  out->setAllocationDomain(out->getLoopDomain(), true);
+
+  at::Tensor in_tensor = at::randn({3, 2}, tensor_options).t();
+  FusionExecutorCache executor_cache(std::move(fusion));
+  at::Tensor out_tensor = executor_cache.runFusionWithInputs({in_tensor})[0];
+
+  testValidate(
+      executor_cache.fusion(),
+      {out_tensor},
+      {in_tensor},
+      {in_tensor},
+      __LINE__,
+      __FILE__);
+}
+
 class MultiDeviceBroadcastTest : public MultiDeviceTest,
                                  public testing::WithParamInterface<bool> {};
@@ -392,20 +462,28 @@ TEST_P(MultiDeviceBroadcastTest, Expanded) {
   TensorView* in = TensorViewBuilder()
                        .dtype(DataType::Float)
                        .contiguity({std::nullopt, true})
-                       .shape({3, -1})
+                       .shape({num_devices * 3, -1})
                        .expanded({true, false})
                        .build();
   in->setDeviceMesh(mesh);
-  if (parallelizes_broadcast) {
-    in->axis(0)->parallelize(ParallelType::DIDx);
-  }
   TensorView* out = set(in);
   fusion->addInput(in);
   fusion->addOutput(out);
 
+  if (parallelizes_broadcast) {
+    for (auto* tv : {in, out}) {
+      tv->split(0, num_devices, /*inner_split=*/false);
+      tv->axis(0)->parallelize(ParallelType::DIDx);
+      tv->setAllocationDomain(tv->getLoopDomain(), true);
+    }
+  }
+
   FusionExecutorCache executor_cache(std::move(fusion));
   auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
-  at::Tensor in_tensor = at::randn({8}, options).as_strided({3, 8}, {0, 1});
+  at::Tensor in_tensor =
+      at::randn({8}, options)
+          .as_strided(
+              {parallelizes_broadcast ? 3 : num_devices * 3, 8}, {0, 1});
   at::Tensor out_tensor = executor_cache.runFusionWithInputs({in_tensor})[0];
   testValidate(
       executor_cache.fusion(), {out_tensor}, {in_tensor}, __LINE__, __FILE__);
diff --git a/tests/cpp/test_sharding.cpp b/tests/cpp/test_sharding.cpp
index e996b80be41..62d14cf43d5 100644
--- a/tests/cpp/test_sharding.cpp
+++ b/tests/cpp/test_sharding.cpp
@@ -23,26 +23,44 @@ namespace nvfuser {
 
 using ShardingTest = NVFuserFixtureParamTest<bool>;
 
-// TODO: This test checks that isSharded generates an error when a split/merged
-// axis is parallelized with DIDx. Update when this restriction is lifted.
-TEST_F(ShardingTest, IsSharded) {
+TEST_F(ShardingTest, LogicalIsSharded) {
   Fusion fusion;
   FusionGuard fg(&fusion);
 
-  TensorView* a = makeSymbolicTensor(3);
-  a->axis(2)->parallelize(ParallelType::DIDx);
-  a->split(0, 4);
-  EXPECT_TRUE(isSharded(a)) << "DIDx on logical domain";
+  TensorView* x = makeSymbolicTensor(3);
+  x->axis(2)->parallelize(ParallelType::DIDx);
+  x->split(0, 4);
 
-  TensorView* b = makeSymbolicTensor(3);
-  b->split(1, 4);
-  b->axis(1)->parallelize(ParallelType::DIDx);
-  EXPECT_TRUE(isSharded(b)) << "DIDx on loop domain";
-
-  TensorView* c = makeSymbolicTensor(3);
-  c->axis(0)->parallelize(ParallelType::DIDx);
-  c->axis(1)->parallelize(ParallelType::DIDx);
-  EXPECT_ANY_THROW(isSharded(c)) << "Multiple DIDx";
+  EXPECT_TRUE(isSharded(x)) << "DIDx on logical domain:" << std::endl
+                            << x->domain()->toString(0, /*loop_only=*/false);
+}
+
+TEST_F(ShardingTest, AllocationIsSharded) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  TensorView* x = makeSymbolicTensor(3);
+  x->split(1, 4);
+  x->axis(1)->parallelize(ParallelType::DIDx);
+  x->setAllocationDomain(x->getLoopDomain(), true);
+
+  EXPECT_TRUE(isSharded(x)) << "DIDx on allocation domain:" << std::endl
+                            << x->domain()->toString(0, /*loop_only=*/false);
+}
+
+TEST_F(ShardingTest, MultipleDIDx) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  TensorView* x = makeSymbolicTensor(1);
+  x->split(0, 2);
+  x->axis(0)->parallelize(ParallelType::DIDx);
+  x->axis(1)->parallelize(ParallelType::DIDx);
+  x->setAllocationDomain(x->getLoopDomain(), true);
+
+  EXPECT_ANY_THROW(isSharded(x))
+      << "Multiple DIDx:" << std::endl
+      << x->domain()->toString(0, /*loop_only=*/false);
 }
 
 TEST_F(ShardingTest, ReductionShouldNotBeSharded) {
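
Note for reviewers (not part of the patch): the snippet below is a minimal sketch of how the new unshardedSizes utility is expected to behave for the example given in the csrc/multidevice/utils.h comment, under the no-Merge and even-Split limitations stated there. The test name UnshardedSizesSketch, the hard-coded mesh size D, and the local sizes are hypothetical; the helpers it uses (NVFuserTest, makeContigConcreteTensor, DeviceMesh::createForNumDevices) are the ones already used by the tests in this diff.

// Reviewer's sketch only -- not included in the diff above.
#include <gtest/gtest.h>

#include <fusion.h>
#include <multidevice/utils.h>
#include <ops/all_ops.h>
#include <tests/cpp/utils.h>

namespace nvfuser {

// Mirrors the example in csrc/multidevice/utils.h:
//   logical:    iM{2}, iN{3D}
//   allocation: iDIDx{D}, iN/D{3}, iM{2}
TEST_F(NVFuserTest, UnshardedSizesSketch) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  constexpr int64_t D = 4; // hypothetical mesh size
  const auto mesh = DeviceMesh::createForNumDevices(D);

  TensorView* tv = makeContigConcreteTensor({2, 3 * D});
  tv->setDeviceMesh(mesh);

  // Outer-split iN by D, move iM to the back, DID-parallelize the outer
  // factor, and make the loop domain the allocation domain -- the same
  // scheduling as MultiDeviceTest.LoopSplitWithReorder above.
  tv->split(1, D, /*inner_split=*/false);
  tv->reorder({{0, -1}});
  tv->axis(0)->parallelize(ParallelType::DIDx);
  tv->setAllocationDomain(tv->getLoopDomain(), true);

  // The per-rank tensor has logical-order sizes [2, 3]; unshardedSizes
  // multiplies the sharded logical dimension by the mesh extent of the
  // DID-parallel allocation IterDomain, recovering [2, 3 * D].
  const std::vector<int64_t> local_sizes = {2, 3};
  EXPECT_EQ(
      unshardedSizes(tv, local_sizes), (std::vector<int64_t>{2, 3 * D}));
}

} // namespace nvfuser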