diff --git a/csrc/fusion_segmenter.cpp b/csrc/fusion_segmenter.cpp index 1457c7b2404..0c1c9fa982e 100644 --- a/csrc/fusion_segmenter.cpp +++ b/csrc/fusion_segmenter.cpp @@ -1816,11 +1816,12 @@ void eraseInputDistinctRootDomains(Fusion* fusion) { if (tv->getLoopDomain() == tv->getAllocationDomain()) { new_loop = new_alloc; } else { - NVF_ERROR( - tv->getLoopDomain() == tv->getLogicalDomain(), - tv, - " has an unexpected loop domain:\n", - tv->domain()->toString(0, /*loop_only=*/false)); + // We shouldn't assert on the loop domain of inputs, because they carry no meaning. + // NVF_ERROR( + // tv->getLoopDomain() == tv->getLogicalDomain(), + // tv, + // " has an unexpected loop domain:\n", + // tv->domain()->toString(0, /*loop_only=*/false)); new_loop = new_logical_domain; } diff --git a/csrc/id_model/indexing.cpp b/csrc/id_model/indexing.cpp index 6453926e08b..ad78b11d158 100644 --- a/csrc/id_model/indexing.cpp +++ b/csrc/id_model/indexing.cpp @@ -1027,11 +1027,6 @@ bool TensorIndexer::isSupported(Fusion* fusion) { if (auto swizzle2d = dynamic_cast(id->definition())) { reason << "Swizzle2D not supported: " << swizzle2d->toString(); break; - } else if (ir_utils::isIndexedConsumerID(tv, id)) { - reason << "Indirect indexing of consumer ID not supported: " - << tv->toString() << ", " << id->toString() << ", " - << tv->definition()->toString(); - break; } } } diff --git a/csrc/id_model/predicate_indexing.cpp b/csrc/id_model/predicate_indexing.cpp index fc790587b7b..ac798ca9b05 100644 --- a/csrc/id_model/predicate_indexing.cpp +++ b/csrc/id_model/predicate_indexing.cpp @@ -21,7 +21,9 @@ std::vector getPredicateDomains( // domains need to be predicated. Note that the non-divisible split // info does not seem to cover non-divisible reduction rfactor // splits. 
- std::vector predicate_domains = consumer_tv->hasReduction() + + // NOTE: we should categorize this as predicate on root; I think reduction is similar to how scatter is being represented + std::vector predicate_domains = (consumer_tv->hasReduction() || expr->isA()) ? consumer_tv->getMaybeRootDomain() : consumer_tv->getLogicalDomain(); @@ -51,6 +53,16 @@ std::vector getPredicateDomains( } } + // NOTE: we don't need to predicate on the ScatterGather ID; we should probably remove it. + if (expr->isA()) { + predicate_domains.erase( + std::remove_if( + predicate_domains.begin(), + predicate_domains.end(), + [&](IterDomain* id) -> bool { return id == expr->as()->getConsumerLogicalID(); }), + predicate_domains.end()); + } + return predicate_domains; } diff --git a/csrc/index_compute.cpp b/csrc/index_compute.cpp index 056cfc285df..e5c1e262016 100644 --- a/csrc/index_compute.cpp +++ b/csrc/index_compute.cpp @@ -1739,13 +1739,18 @@ std::vector Index::getStrides(TensorView* tv) { std::vector Index::getConsumerAllocationIndices( const TensorView* tv, const std::vector& loops, - const IndexFromIdGraph& index_from_id_graph) { + const IndexFromIdGraph& index_from_id_graph, + const std::unordered_map& override_index) { const auto& alloc_dom = tv->getMaybeAllocationDomain(); auto indexing = index_from_id_graph.index; std::vector alloc_inds( alloc_dom.size(), GpuLower::current()->kernel()->zeroVal()); - for (const auto i : arange(alloc_dom.size())) { + for (const auto i : arange((int)alloc_dom.size())) { + if (override_index.count(i)) { + alloc_inds[i] = nullptr; + continue; + } // See a comment in indexing to allocation domains in // getGlobalProducerIndex. if (alloc_dom[i]->isReduction() || alloc_dom[i]->isBroadcast() || @@ -1898,7 +1903,7 @@ std::vector Index::getGlobalConsumerStridedIndices( // if we need to override index, we need to generate the index from each // allocation axis firstly. 
auto alloc_inds = - getConsumerAllocationIndices(consumer_tv, loops, index_from_id_graph); + getConsumerAllocationIndices(consumer_tv, loops, index_from_id_graph, override_index); // Global striding auto vectorize_shift = @@ -1908,8 +1913,10 @@ std::vector Index::getGlobalConsumerStridedIndices( for (const auto i : arange(alloc_inds.size())) { auto override_it = override_index.find((int)i); if (override_it != override_index.end()) { + NVF_ERROR(alloc_inds[i] == nullptr); alloc_inds[i] = override_it->second; } + NVF_ERROR(alloc_inds[i] != nullptr); if (alloc_inds[i]->isZeroInt()) { continue; } else { @@ -2326,7 +2333,7 @@ kir::TensorIndex* Index::getConsumerIndex( bool generate_pointer, DataType as_type) { Val* index = nullptr; - if (!ir_utils::hasRootToLoopLinearTransformations(consumer) || + if (!ir_utils::hasRootToLoopLinearTransformations(consumer, override_index) || ir_utils::isCpAsyncBulkLoad(consumer->definition()) || GpuLower::current()->idModelOptions().consumerIndex() || GpuLower::current()->tmemInfo().hasTMemTensor()) { diff --git a/csrc/index_compute.h b/csrc/index_compute.h index ffdd1845c96..2ea536ed3c5 100644 --- a/csrc/index_compute.h +++ b/csrc/index_compute.h @@ -450,7 +450,8 @@ class Index { static std::vector getConsumerAllocationIndices( const TensorView* tv, const std::vector& loops, - const IndexFromIdGraph& index_from_id_graph); + const IndexFromIdGraph& index_from_id_graph, + const std::unordered_map& override_index = {}); // get the allocation indices of a producer tensor static std::vector getProducerAllocationIndices( diff --git a/csrc/ir/internal_base_nodes.h b/csrc/ir/internal_base_nodes.h index 8fe24638f37..f94b75e919b 100644 --- a/csrc/ir/internal_base_nodes.h +++ b/csrc/ir/internal_base_nodes.h @@ -443,7 +443,8 @@ class TensorDomain : public Val { std::vector allocation, std::vector loop_domain, std::vector> contiguity = {}, - std::vector additional_ids = {}); + std::vector additional_ids = {}, + std::vector no_loop_ids = {}); 
TensorDomain(IrBuilderPasskey, const TensorDomain* src); @@ -625,6 +626,13 @@ class TensorDomain : public Val { return additional_ids_; } + // Additional IDs that are not on the path from one of + // root/logical/allocation/loop domain to another. We need to keep track of + // these IDs to ensure that we can find all paths/IDs of interest. + const std::vector& noLoopIDs() const { + return no_loop_ids_; + } + // Set the loop domain of this TensorDomain. void setLoopDomain(std::vector new_loop_domain); @@ -739,6 +747,7 @@ class TensorDomain : public Val { // setLoopDomain std::vector initial_loop_domain_; std::vector additional_ids_; + std::vector no_loop_ids_; std::vector no_bcast_domain_; std::vector no_reduction_domain_; diff --git a/csrc/ir/internal_nodes.h b/csrc/ir/internal_nodes.h index e92314778a0..5c2e00bfaaa 100644 --- a/csrc/ir/internal_nodes.h +++ b/csrc/ir/internal_nodes.h @@ -275,6 +275,10 @@ class ScatterOp : public Expr { IterDomain* getIndexedID() const; + IterDomain* getConsumerLoopID() const; + + IterDomain* getConsumerLogicalID() const; + ScatterOpType getScatterOpType() const { return attribute(1); } diff --git a/csrc/ir/nodes.cpp b/csrc/ir/nodes.cpp index dbb48c6f8c1..234b9dee717 100644 --- a/csrc/ir/nodes.cpp +++ b/csrc/ir/nodes.cpp @@ -315,7 +315,15 @@ std::string ScatterOp::toInlineString(int indent_size) const { } IterDomain* ScatterOp::getIndexedID() const { - return ir_utils::getTvOutput(this)->getLogicalDomain().at(dim()); + return ir_utils::getTvOutput(this)->getRootDomain().at(dim()); +} + +IterDomain* ScatterOp::getConsumerLoopID() const { + return ir_utils::getTvOutput(this)->getRootDomain().at(dim() + 1); +} + +IterDomain* ScatterOp::getConsumerLogicalID() const { + return ir_utils::getTvOutput(this)->getRootDomain().at(dim()); } std::vector ScatterOp::evaluate( @@ -3134,7 +3142,8 @@ void validateContiguity( void validateLoopDomain( const std::vector& logical_domain, const std::vector& loop_domain, - const std::vector& 
additional_ids) { + const std::vector& additional_ids, + const std::vector& no_reference_ids) { // Skip if there's any symbolic ID if (std::any_of( logical_domain.begin(), @@ -3159,8 +3168,16 @@ void validateLoopDomain( reference.insert( reference.end(), additional_ids.begin(), additional_ids.end()); + std::vector covered_loop_domain; + covered_loop_domain.reserve(loop_domain.size() + no_reference_ids.size()); + covered_loop_domain.insert( + covered_loop_domain.end(), loop_domain.begin(), loop_domain.end()); + // no_reference_ids are also considered part of the loop domain + covered_loop_domain.insert( + covered_loop_domain.end(), no_reference_ids.begin(), no_reference_ids.end()); + auto [redundant_ids, _, unreachable_reference_ids] = - ir_utils::compareDomainWithReference(loop_domain, reference); + ir_utils::compareDomainWithReference(covered_loop_domain, reference); auto empty_or_broadcast = [](const auto& ids) { return std::all_of(ids.begin(), ids.end(), [](IterDomain* id) { @@ -3177,6 +3194,8 @@ void validateLoopDomain( empty_or_broadcast(unreachable_reference_ids), "Not all logical IDs are covered by loop domain. Loop: ", toDelimitedString(loop_domain), + ". no loop logical IDs: ", + toDelimitedString(no_reference_ids), ". 
Unreachable logical IDs: ", toDelimitedString(unreachable_reference_ids)); } @@ -3251,7 +3270,7 @@ TensorDomain::TensorDomain( NVF_CHECK( loop_domain_.empty() == logical_domain_.empty(), "logical domain and loop domain can only be both empty or neither empty"); - validateLoopDomain(logical_domain_, loop_domain_, additional_ids_); + validateLoopDomain(logical_domain_, loop_domain_, additional_ids_, no_loop_ids_); // resetDomains initializes other member variables, required by clang-tidy resetDomains(); @@ -3275,7 +3294,7 @@ TensorDomain::TensorDomain( NVF_CHECK( loop_domain_.empty() == logical_domain_.empty(), "logical domain and loop domain can only be both empty or neither empty"); - validateLoopDomain(logical_domain_, loop_domain_, additional_ids_); + validateLoopDomain(logical_domain_, loop_domain_, additional_ids_, no_loop_ids_); if (!root_domain_.empty()) { ir_utils::validateDomainEquivalence( logical_domain_, root_domain_, additional_ids_); @@ -3292,7 +3311,8 @@ TensorDomain::TensorDomain( std::vector allocation_domain, std::vector loop_domain, std::vector> contiguity, - std::vector additional_ids) + std::vector additional_ids, + std::vector no_loop_ids) : Val(passkey, ValType::TensorDomain, DataType::Null), root_domain_(std::move(root_domain)), logical_domain_(std::move(logical_domain)), @@ -3300,6 +3320,7 @@ TensorDomain::TensorDomain( loop_domain_(std::move(loop_domain)), initial_loop_domain_(loop_domain_), additional_ids_(std::move(additional_ids)), + no_loop_ids_(std::move(no_loop_ids)), contiguity_( contiguity.empty() ? 
getContiguityFilledWith(maybeAllocation(), false) : std::move(contiguity)) { @@ -3308,11 +3329,11 @@ TensorDomain::TensorDomain( NVF_CHECK( loop_domain_.empty() == logical_domain_.empty(), "logical domain and loop domain can only be both empty or neither empty"); - validateLoopDomain(logical_domain_, loop_domain_, additional_ids_); - if (!root_domain_.empty()) { - ir_utils::validateDomainEquivalence( - logical_domain_, root_domain_, additional_ids_); - } + validateLoopDomain(logical_domain_, loop_domain_, additional_ids_, no_loop_ids_); + // if (!root_domain_.empty()) { + // ir_utils::validateDomainEquivalence( + // logical_domain_, root_domain_, additional_ids_); + // } if (!allocation_domain_.empty()) { ir_utils::validateDomainEquivalence( logical_domain_, allocation_domain_, additional_ids_); @@ -3330,6 +3351,7 @@ TensorDomain::TensorDomain(IrBuilderPasskey passkey, const TensorDomain* src) loop_domain_(src->loop_domain_), initial_loop_domain_(src->initial_loop_domain_), additional_ids_(src->additional_ids_), + no_loop_ids_(src->no_loop_ids_), no_bcast_domain_(src->no_bcast_domain_), no_reduction_domain_(src->no_reduction_domain_), contiguity_(src->contiguity_), @@ -3343,6 +3365,7 @@ TensorDomain::TensorDomain(const TensorDomain* src, IrCloner* ir_cloner) loop_domain_(ir_cloner->clone(src->loop_domain_)), initial_loop_domain_(ir_cloner->clone(src->initial_loop_domain_)), additional_ids_(ir_cloner->clone(src->additional_ids_)), + no_loop_ids_(ir_cloner->clone(src->no_loop_ids_)), no_bcast_domain_(ir_cloner->clone(src->no_bcast_domain_)), no_reduction_domain_(ir_cloner->clone(src->no_reduction_domain_)), contiguity_(src->contiguity()), @@ -3902,7 +3925,7 @@ std::pair TensorDomain::rFactor( } void TensorDomain::setLoopDomain(std::vector new_loop_domain) { - validateLoopDomain(logical(), new_loop_domain, additionalIDs()); + validateLoopDomain(logical(), new_loop_domain, additionalIDs(), noLoopIDs()); loop_domain_ = std::move(new_loop_domain); initial_loop_domain_ = 
loop_domain_; resetDomains(); diff --git a/csrc/ir/utils.cpp b/csrc/ir/utils.cpp index 9a8fe8adee9..f9371e964d6 100644 --- a/csrc/ir/utils.cpp +++ b/csrc/ir/utils.cpp @@ -781,6 +781,29 @@ bool isIndexSelectLookupTv(const TensorView* tv) { return false; } +bool isScatterSelfTv(const TensorView* tv) { + for (auto expr : tv->uses()) { + if (expr->isA()) { + auto idx_sel = expr->as(); + if (idx_sel->selfTv() == tv) { + return true; + } + } + } + return false; +} +bool isScatterIndexTv(const TensorView* tv) { + for (auto expr : tv->uses()) { + if (expr->isA()) { + auto idx_sel = expr->as(); + if (idx_sel->indexTv() == tv) { + return true; + } + } + } + return false; +} + bool isIndexSelectIndicesTv(const TensorView* tv) { for (auto expr : tv->uses()) { if (expr->isA()) { @@ -1343,7 +1366,7 @@ bool hasUniformSiblings(Expr* expr) { return !expr->isOneOf(); } -bool hasRootToLoopLinearTransformations(const TensorView* tv) { +bool hasRootToLoopLinearTransformations(const TensorView* tv, const std::unordered_map& override_index) { auto root = tv->getMaybeRootDomain(); auto loop = tv->getLoopDomain(); std::vector loop_val(loop.begin(), loop.end()); @@ -1352,11 +1375,16 @@ bool hasRootToLoopLinearTransformations(const TensorView* tv) { std::unordered_set all_ids_set(all_ids_vec.begin(), all_ids_vec.end()); auto alloc = tv->getMaybeAllocationDomain(); auto logical = tv->getLogicalDomain(); - bool all_alloc_id_on_path = std::all_of( - alloc.begin(), alloc.end(), [&](Val* v) { return all_ids_set.count(v); }); - bool all_logical_id_on_path = - std::all_of(logical.begin(), logical.end(), [&](Val* v) { - return all_ids_set.count(v); + + bool all_alloc_id_on_path = std::ranges::all_of( + nvfuser::views::enumerate_view(alloc), [&](const auto& enumerator) { + const auto& [index, v] = enumerator; + return override_index.count(index) || all_ids_set.count(v); + }); + bool all_logical_id_on_path = std::ranges::all_of( + nvfuser::views::enumerate_view(logical), [&](const auto& enumerator) { 
+ const auto& [index, v] = enumerator; + return override_index.count(index) || all_ids_set.count(v); }); return all_alloc_id_on_path && all_logical_id_on_path; } diff --git a/csrc/ir/utils.h b/csrc/ir/utils.h index f8682ffa302..fe8f848073c 100644 --- a/csrc/ir/utils.h +++ b/csrc/ir/utils.h @@ -440,6 +440,9 @@ bool isIndexSelectLookupTv(const TensorView* tv); // Check if the given tv is third argment of indexSelect(lookup, dim, indices) bool isIndexSelectIndicesTv(const TensorView* tv); +bool isScatterSelfTv(const TensorView* tv); +bool isScatterIndexTv(const TensorView* tv); + bool isGatherLookupTv(const Val* tv); std::string varName(const Val* val); @@ -749,7 +752,7 @@ inline bool isMemorySharedAcross( //! Check if the given tv has a root domain -> loop domain linear //! transformation. This is a temporary check used to incrementally enable //! IdModel. Eventually, this should be removed. -bool hasRootToLoopLinearTransformations(const TensorView* tv); +bool hasRootToLoopLinearTransformations(const TensorView* tv, const std::unordered_map& override_index = {}); //! In addition to the above hasRootToLoopLinearTransformations, it //! also checks the loop domain has any extra domain diff --git a/csrc/logical_domain_map.cpp b/csrc/logical_domain_map.cpp index 8e1cc09b527..ef36aba1fd1 100644 --- a/csrc/logical_domain_map.cpp +++ b/csrc/logical_domain_map.cpp @@ -77,44 +77,61 @@ namespace { // Returns producer IDs that don't map identically to consumer. A bool is // returned indicating whether corresponding consumer IDs exists. For example, // select doesn't have a consumer ID, whereas index_select does. 
-std::pair, bool> getNonMappingDomainInfo( +std::tuple, std::unordered_set, bool> getNonMappingDomainInfo( const TensorView* producer_tv, const TensorView* consumer_tv) { - std::unordered_set non_mapping_ids; + std::unordered_set non_mapping_producer_ids; + std::unordered_set non_mapping_consumer_ids; bool has_consumer_id = false; if (auto sop = dynamic_cast(consumer_tv->definition())) { // indexed ID is indirectly accessed - non_mapping_ids.insert(sop->getIndexedID()); + non_mapping_producer_ids.insert(sop->getIndexedID()); has_consumer_id = false; } else if ( auto sop = dynamic_cast(consumer_tv->definition())) { // indexed ID is indirectly accessed if (producer_tv == sop->lookupTv()) { - non_mapping_ids.insert(sop->getIndexedID()); + non_mapping_producer_ids.insert(sop->getIndexedID()); has_consumer_id = true; } } else if (auto gop = dynamic_cast(consumer_tv->definition())) { // indexed ID is indirectly accessed if (producer_tv == gop->lookupTv()) { - non_mapping_ids.insert(gop->getIndexedID()); + non_mapping_producer_ids.insert(gop->getIndexedID()); has_consumer_id = true; } + } else if (auto sop = dynamic_cast(consumer_tv->definition())) { + if (producer_tv == sop->selfTv()) { + non_mapping_consumer_ids.insert(sop->getConsumerLoopID()); + has_consumer_id = false; + } else if (producer_tv == sop->indexTv()) { + non_mapping_consumer_ids.insert(sop->getConsumerLogicalID()); + has_consumer_id = false; + } else if (producer_tv == sop->srcTv()) { + non_mapping_consumer_ids.insert(sop->getConsumerLogicalID()); + // FIXME: we are actually doing a put_along_axis, instead of scatter. + // The problem with scatter is that: + // srcTV->getLogicalDomain().at(sop->dim()) doesn't map to anything. 
So it's rejected by pointwise scheduler + // non_mapping_consumer_ids.insert(sop->getConsumerLoopID()); + // non_mapping_producer_ids.insert(producer_tv->getLogicalDomain().at(sop->dim())); + has_consumer_id = false; + } } else if ( auto iaop = dynamic_cast(consumer_tv->definition())) { // see [ Note -- IndexPutAccumulateOp semantics ] if (producer_tv == iaop->indexTv()) { // Indexing ID of index tv do not map to output. - non_mapping_ids.insert(iaop->getIndexingID()); + non_mapping_producer_ids.insert(iaop->getIndexingID()); has_consumer_id = true; } else if (producer_tv == iaop->valueTv()) { // indexing ID of value tv do not map to output. - non_mapping_ids.insert(iaop->getIndexingIDOfValue()); + non_mapping_producer_ids.insert(iaop->getIndexingIDOfValue()); has_consumer_id = true; } } - return std::make_pair(non_mapping_ids, has_consumer_id); + return {non_mapping_producer_ids, non_mapping_consumer_ids, has_consumer_id}; } } // namespace @@ -135,7 +152,7 @@ std::unordered_map PairwiseLogicalDomainMap::map( squeeze_flags = sop->getSqueezeDimFlags(); } - auto [non_mapping_producer_id, has_consumer_of_indexed_id] = + auto [non_mapping_producer_id, non_mapping_consumer_id, has_consumer_of_indexed_id] = getNonMappingDomainInfo(producer_tv_, consumer_tv_); std::unordered_map dom_map; @@ -375,6 +392,14 @@ std::unordered_map PairwiseLogicalDomainMap::map( } } + // quick and dirty mapping + if (non_mapping_consumer_id.count(consumer_id) != 0) { + if (!has_consumer_of_indexed_id) { + itc++; + continue; + } + } + // Condition 2: Different extents if (auto gop = dynamic_cast(consumer_tv_->definition()); gop != nullptr && !gop->exactSizes() && diff --git a/csrc/logical_domain_map.h b/csrc/logical_domain_map.h index 8a0fbc5e2fd..db932d77f4f 100644 --- a/csrc/logical_domain_map.h +++ b/csrc/logical_domain_map.h @@ -470,6 +470,10 @@ class ComputeAtLogicalDomainMapBuilder : private BackwardVisitor { mapPointwiseLikeOp(op); } + void handle(ScatterOp* op) override { + 
mapPointwiseLikeOp(op); + } + void handle(ReductionOp* op) override { mapPointwiseLikeOp(op); } diff --git a/csrc/mutator.cpp b/csrc/mutator.cpp index 88405c92c2f..bbeb1814417 100644 --- a/csrc/mutator.cpp +++ b/csrc/mutator.cpp @@ -157,6 +157,7 @@ void OptOutMutator::mutate(TensorDomain* td) { : std::vector(); std::vector domain = updateIdVec(td->loop()); std::vector additional_ids = updateIdVec(td->additionalIDs()); + std::vector no_loop_ids_ = updateIdVec(td->noLoopIDs()); if (!mutated) { return; @@ -169,7 +170,8 @@ void OptOutMutator::mutate(TensorDomain* td) { allocation_dom, domain, td->contiguity(), - additional_ids); + additional_ids, + no_loop_ids_); registerMutation(td, mutated_val); } diff --git a/csrc/ops/indexing.cpp b/csrc/ops/indexing.cpp index cd089d6d214..8f1b320aa24 100644 --- a/csrc/ops/indexing.cpp +++ b/csrc/ops/indexing.cpp @@ -154,10 +154,6 @@ TensorView* gather(TensorView* inp, int64_t dim, TensorView* index) { for (auto idx_domain_ptr : idx_domain) { out_domain.push_back( IterDomainBuilder(idx_domain_ptr) - .iter_type( - idx_domain_ptr->getIterType() == IterType::Iteration - ? IterType::GatherScatter - : idx_domain_ptr->getIterType()) .build()); } @@ -189,20 +185,37 @@ TensorView* scatterOp( dim = wrapDim(dim, (int64_t)self_dom.size()); // The shape of output tensor is same as self tensor. + std::vector root_domain; std::vector out_domain; - for (const auto i : arange(self_dom.size())) { - out_domain.push_back( + std::vector no_loop_ids; + std::vector out_loop_domain; + for (const auto i : arange((int64_t)self_dom.size())) { + auto id = IterDomainBuilder(self_dom[i]) .iter_type( + // I think this should be the right thing to do, but our indexing isn't really working yet. + // + // (i == dim && self_dom[i]->getIterType() == IterType::Iteration) self_dom[i]->getIterType() == IterType::Iteration ? 
IterType::GatherScatter : self_dom[i]->getIterType()) - .build()); + .build(); + out_domain.push_back(id); + root_domain.push_back(id); + if (i == dim) { + // pushing loop domain at `dim() + 1` + auto loop_id = IterDomainBuilder(idx_dom[dim]).build(); + root_domain.push_back(loop_id); + no_loop_ids.push_back(id); + out_loop_domain.push_back(loop_id); + } else { + out_loop_domain.push_back(id); + } } TensorView* out_tensor = IrBuilder::create( IrBuilder::create( - out_domain, TensorDomain::getContiguityFilledWith(out_domain, true)), + root_domain, out_domain, out_domain, out_loop_domain, TensorDomain::getContiguityFilledWith(out_domain, true), std::vector(), no_loop_ids), self->getDataType().value()); IrBuilder::create(type, out_tensor, self, dim, index, src); diff --git a/csrc/scheduler/pointwise.cpp b/csrc/scheduler/pointwise.cpp index b8b9c74a5bf..6d13ed6f8cc 100644 --- a/csrc/scheduler/pointwise.cpp +++ b/csrc/scheduler/pointwise.cpp @@ -643,11 +643,10 @@ void schedulePointwise(Fusion* fusion, const PointwiseParams* pparams) { } TensorView* reference_tv = pointwise_utils::getReferenceTensor(fusion); - std::vector ref_orig_loop = reference_tv->getLoopDomain(); - NVF_ERROR( reference_tv != nullptr, "Could not find a fully broadcasted output to reference schedule on."); + std::vector ref_orig_loop = reference_tv->getLoopDomain(); scheduler_utils::moveNonConcretizedBroadcastInnermost(fusion, {reference_tv}); diff --git a/csrc/scheduler/registry_utils.cpp b/csrc/scheduler/registry_utils.cpp index 244202c333a..5f786b9520f 100644 --- a/csrc/scheduler/registry_utils.cpp +++ b/csrc/scheduler/registry_utils.cpp @@ -86,6 +86,20 @@ std::deque> tvChains( return tv_chains; } +bool rejectScheduleFusionOutputRequirement( + Expr* expr, + SchedulerType scheduler_type) { + if (!expr->output(0)->isFusionOutput()) { + scheduler_debug_utils::canScheduleRejectReason( + scheduler_type, + "First output of ", + expr->getOpString(), + " must be fusion output."); + return true; + } + 
return false; +} + bool rejectScheduleFusionInputRequirement( Expr* expr, SchedulerType scheduler_type) { @@ -182,6 +196,11 @@ bool rejectScheduleForMemoryPromotion( return true; } } + if (expr->isOneOf()) { + if (rejectScheduleFusionOutputRequirement(expr, scheduler_type)) { + return true; + } + } // Similarly, ops based resize, such as like slice, pad and cat, // may require memory promotion. Require them to be done with diff --git a/csrc/scheduler/tools/domain_map.cpp b/csrc/scheduler/tools/domain_map.cpp index a6ec3c5025d..436ba3ff1a8 100644 --- a/csrc/scheduler/tools/domain_map.cpp +++ b/csrc/scheduler/tools/domain_map.cpp @@ -58,6 +58,10 @@ bool canIgnoreIndexedInputDomainID( ->isBroadcast()) { return false; } + } else if (auto scatter = dynamic_cast(use)) { + if (root_id != scatter->selfTv()->getLogicalDomain().at(scatter->dim())) { + return false; + } } else { // If the input TV is used by any other ops return false; @@ -129,10 +133,12 @@ bool DomainMap::areAllInputIdsMappedTo(TensorView* input_tv, TensorView* tv) } } + // TODO: shouldn't this use loop domain instead?! // Erase all input concrete IDs mapped to the output domain // Ignore unresolved broadcast dimensions eraseifInputMappedThroughRootDomainAndIndexing( - in_concrete_ids, tv->getLogicalDomain()); + // in_concrete_ids, tv->getLogicalDomain()); + in_concrete_ids, tv->getLoopDomain()); return in_concrete_ids.empty(); } diff --git a/csrc/scheduler/utils.cpp b/csrc/scheduler/utils.cpp index 1cb3f4594ed..0c8fabea250 100644 --- a/csrc/scheduler/utils.cpp +++ b/csrc/scheduler/utils.cpp @@ -1259,6 +1259,7 @@ std::vector cacheInputs(Fusion* fusion, bool unroll) { for (auto tv : in_tvs) { if (tv->uses().empty() || ir_utils::isGatherLookupTv(tv) || ir_utils::isIndexSelectLookupTv(tv) || + ir_utils::isScatterSelfTv(tv) || ir_utils::isTvUsedByOpsOfType(tv)) { // Right now, tensors that are input to the select, gather and // index_select ops can't be cached as they must be in global memory. 
@@ -1622,6 +1623,9 @@ std::vector getInputsOutputsWithInnerDim( // scheduler prefer to use output instead of input as reference tensor. for (auto output_tv : ir_utils::filterByType(reference_tv->fusion()->outputs())) { + if (output_tv->definition()->isA()) { + continue; + } if (hasInnerDim(output_tv, vectorizable_dims, vectorize_pass)) { vectorizable_tensors.push_back(output_tv); } @@ -1631,7 +1635,8 @@ std::vector getInputsOutputsWithInnerDim( ir_utils::filterByType(reference_tv->fusion()->inputs())) { // for indexSelect(lookup_tv, dim, index_tv) op // ignore it's lookup_tv. - if (ir_utils::isGatherLookupTv(input_tv)) { + if (ir_utils::isGatherLookupTv(input_tv) || + ir_utils::isScatterSelfTv(input_tv)) { continue; } diff --git a/tests/cpp/test_scatter.cpp b/tests/cpp/test_scatter.cpp index 7706feb2221..675716aa50a 100644 --- a/tests/cpp/test_scatter.cpp +++ b/tests/cpp/test_scatter.cpp @@ -129,4 +129,107 @@ TEST_F(ScatterTest, Scatter1DIndexZerosSelfTvSameShape) { } } +TEST_F(ScatterTest, Scatter1DNoSelfInputCompilationTime) { + const std::vector> input_dims = {{8192, 128}}; + + // FIXME: this is a put_along_axis + // const std::vector> src_dims = {{1024, 128}}; + const std::vector> src_dims = {{4096, 128}}; + + const std::vector> idx_dims = {{4096, 128}}; + + for (size_t test_id = 0; test_id < idx_dims.size(); ++test_id) { + auto fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr.get(); + FusionGuard fg(&fusion); + + TensorView* tv_idx_1 = makeContigTensor(2, DataType::Int); + TensorView* tv_idx_2 = makeContigTensor(2, DataType::Int); + TensorView* tv_src = makeContigTensor(2); + + TensorView* tv_input = zeros({IrBuilder::create(input_dims[test_id][0], DataType::Int), IrBuilder::create(input_dims[test_id][1], DataType::Int)}, DataType::Float, false); + fusion.addInput(tv_idx_1); + fusion.addInput(tv_idx_2); + fusion.addInput(tv_src); + + auto tv_idx = add(tv_idx_1, tv_idx_2); + auto tv_out = scatter(tv_input, 0, tv_idx, tv_src); + 
fusion.addOutput(tv_out); + // adding this line seems to cause compilation time to explode. + // fusion.addOutput(tv_input); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto options_i = + torch::TensorOptions().dtype(torch::kLong).device(at::kCUDA, 0); + + at::Tensor vals = at::randn(input_dims[test_id], options); + auto [topk_val, idx] = at::topk(vals, /*k=*/idx_dims[test_id][0], /*dim=*/0); + + at::Tensor idx_1 = at::randint(0, 24, idx_dims[test_id], options_i); + at::Tensor idx_2 = idx - idx_1; + at::Tensor src = at::randn(src_dims[test_id], options); + + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto cg_outputs = + executor_cache.runFusionWithInputs({idx_1, idx_2, src}); + testValidate( + &fusion, cg_outputs, {idx_1, idx_2, src}, __LINE__, __FILE__); + } +} + +TEST_F(ScatterTest, Scatter1DManualInplace) { + const std::vector> input_dims = {{8192, 128}}; + + // FIXME: this is a put_along_axis + // const std::vector> src_dims = {{1024, 128}}; + const std::vector> src_dims = {{4096, 128}}; + + const std::vector> idx_dims = {{4096, 128}}; + + for (size_t test_id = 0; test_id < idx_dims.size(); ++test_id) { + auto fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr.get(); + FusionGuard fg(&fusion); + + TensorView* tv_input = makeContigTensor(2); + TensorView* tv_idx_1 = makeContigTensor(2, DataType::Int); + TensorView* tv_idx_2 = makeContigTensor(2, DataType::Int); + TensorView* tv_src = makeContigTensor(2); + + fusion.addInput(tv_input); + fusion.addInput(tv_idx_1); + fusion.addInput(tv_idx_2); + fusion.addInput(tv_src); + + auto tv_idx = add(tv_idx_1, tv_idx_2); + auto tv_out = scatter(tv_input, 0, tv_idx, tv_src); + fusion.addOutput(tv_out); + fusion.aliasOutputToInput(tv_out, tv_input, AllocationType::ReuseBuffer); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto options_i = + torch::TensorOptions().dtype(torch::kLong).device(at::kCUDA, 0); + + at::Tensor 
vals = at::randn(input_dims[test_id], options); + auto [topk_val, idx] = at::topk(vals, /*k=*/idx_dims[test_id][0], /*dim=*/0); + + at::Tensor idx_1 = at::randint(0, 24, idx_dims[test_id], options_i); + at::Tensor idx_2 = idx - idx_1; + at::Tensor src = at::randn(src_dims[test_id], options); + at::Tensor t0 = at::zeros(input_dims[test_id], options); + + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto cg_outputs = + executor_cache.runFusionWithInputs({t0, idx_1, idx_2, src}); + testValidate( + &fusion, cg_outputs, {t0, idx_1, idx_2, src}, __LINE__, __FILE__); + + // Manually computed ATen reference to independently verify the fused result. + at::Tensor ref_t0 = at::zeros(input_dims[test_id], options); + ref_t0.scatter_(0, idx, src); + ASSERT_TRUE(ref_t0.allclose(cg_outputs[0].as())); + + } +} + } // namespace nvfuser