Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
52 commits
Select commit Hold shift + click to select a range
20582c0
see if it blows up
jjsjann123 Apr 24, 2025
131cbdb
Merge remote-tracking branch 'origin/main' into jj/embedding_ir
jjsjann123 Apr 28, 2025
b02acca
Merge remote-tracking branch 'origin/main' into jj/embedding_ir
jjsjann123 Apr 30, 2025
2673152
Merge remote-tracking branch 'origin/main' into jj/embedding_ir
jjsjann123 May 1, 2025
b7a30b7
Merge remote-tracking branch 'origin/main' into jj/embedding_ir
jjsjann123 May 2, 2025
053cd89
Merge remote-tracking branch 'origin/main' into jj/embedding_ir
jjsjann123 May 5, 2025
d695f5f
I think this would work
jjsjann123 May 5, 2025
37ca5fb
fixing build
jjsjann123 May 6, 2025
d72741c
adding mismatch loop domain
jjsjann123 May 6, 2025
5ac8d05
WIP
jjsjann123 May 6, 2025
9401147
wip
jjsjann123 May 6, 2025
6152426
WIP
jjsjann123 May 6, 2025
c19aebb
wip
jjsjann123 May 6, 2025
a08039b
err
jjsjann123 May 6, 2025
749d2c7
fix
jjsjann123 May 6, 2025
b358ff6
fixing copy no_loop_ids
jjsjann123 May 6, 2025
3b42a18
WIP
jjsjann123 May 7, 2025
c91b5cb
fixing build
jjsjann123 May 7, 2025
fe19828
errr fixing the no_loop_ids
jjsjann123 May 7, 2025
e16770f
fixing loop domain
jjsjann123 May 7, 2025
c741121
errr
jjsjann123 May 7, 2025
77af522
WIP
jjsjann123 May 7, 2025
675724b
wip
jjsjann123 May 7, 2025
ed9d9ef
ANOTHER test with bfs failure
jjsjann123 May 7, 2025
8773393
WIP
jjsjann123 May 8, 2025
bfc93ba
fixing build, no c++23
jjsjann123 May 8, 2025
0de1848
skipping iallocation indices for override ID
jjsjann123 May 8, 2025
2e89f49
fixing declaration
jjsjann123 May 8, 2025
cb82fa9
use root domain for scatter as predicate domains
jjsjann123 May 8, 2025
15a648a
removing logical id from predicate for scatter
jjsjann123 May 8, 2025
99f169c
FIXING IT!
jjsjann123 May 8, 2025
df8853c
WIP
jjsjann123 May 9, 2025
1d02e45
WIP
jjsjann123 May 9, 2025
66d5ad2
modifying the test
jjsjann123 May 9, 2025
d25c0a0
fixing scatter ID retrieval
jjsjann123 May 9, 2025
431a83b
allow mismatch src ID
jjsjann123 May 9, 2025
27b8f90
mapping
jjsjann123 May 9, 2025
bacbe12
switch to put_along_axis
jjsjann123 May 9, 2025
b055bb9
preventing vectorization for scatter
jjsjann123 May 9, 2025
909eb97
WIP
jjsjann123 May 10, 2025
4e9b720
adding another test with manually fed zero-init tensor to validate in…
jjsjann123 May 13, 2025
814fc2d
addressing my trust issue for test reference
jjsjann123 May 13, 2025
ec2abe9
Merge remote-tracking branch 'origin/main' into jj/embedding_ir
jjsjann123 May 13, 2025
f6a7a01
Merge remote-tracking branch 'origin/main' into jj/embedding_ir
jjsjann123 May 19, 2025
67856c5
Merge remote-tracking branch 'origin/main' into jj/embedding_ir
jjsjann123 May 20, 2025
b52fd7c
adding scatter to python
jjsjann123 May 28, 2025
a59acc2
i think this is necessary to fix the mapping for bcast
jjsjann123 May 30, 2025
848d67b
fixing something
jjsjann123 May 30, 2025
a253226
Merge remote-tracking branch 'origin/main' into jj/embedding_ir
jjsjann123 Jun 2, 2025
90b1de1
Merge remote-tracking branch 'origin/main' into jj/embedding_ir
jjsjann123 Jun 2, 2025
16bf64d
Merge remote-tracking branch 'origin/main' into jj/embedding_ir
jjsjann123 Jun 3, 2025
0146b7c
Merge remote-tracking branch 'origin/main' into jj/embedding_ir
jjsjann123 Jun 5, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions csrc/fusion_segmenter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1816,11 +1816,12 @@ void eraseInputDistinctRootDomains(Fusion* fusion) {
if (tv->getLoopDomain() == tv->getAllocationDomain()) {
new_loop = new_alloc;
} else {
NVF_ERROR(
tv->getLoopDomain() == tv->getLogicalDomain(),
tv,
" has an unexpected loop domain:\n",
tv->domain()->toString(0, /*loop_only=*/false));
// We shouldn't assert on the loop domain of fusion inputs, because it carries no meaning.
// NVF_ERROR(
// tv->getLoopDomain() == tv->getLogicalDomain(),
// tv,
// " has an unexpected loop domain:\n",
// tv->domain()->toString(0, /*loop_only=*/false));

new_loop = new_logical_domain;
}
Expand Down
5 changes: 0 additions & 5 deletions csrc/id_model/indexing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1027,11 +1027,6 @@ bool TensorIndexer::isSupported(Fusion* fusion) {
if (auto swizzle2d = dynamic_cast<Swizzle2D*>(id->definition())) {
reason << "Swizzle2D not supported: " << swizzle2d->toString();
break;
} else if (ir_utils::isIndexedConsumerID(tv, id)) {
reason << "Indirect indexing of consumer ID not supported: "
<< tv->toString() << ", " << id->toString() << ", "
<< tv->definition()->toString();
break;
}
}
}
Expand Down
14 changes: 13 additions & 1 deletion csrc/id_model/predicate_indexing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@ std::vector<IterDomain*> getPredicateDomains(
// domains need to be predicated. Note that the non-divisible split
// info does not seem to cover non-divisible reduction rfactor
// splits.
std::vector<IterDomain*> predicate_domains = consumer_tv->hasReduction()

// NOTE: we should categorize this as a predicate on the root domain; reduction appears to be represented similarly to scatter here.
std::vector<IterDomain*> predicate_domains = (consumer_tv->hasReduction() || expr->isA<ScatterOp>())
? consumer_tv->getMaybeRootDomain()
: consumer_tv->getLogicalDomain();

Expand Down Expand Up @@ -51,6 +53,16 @@ std::vector<IterDomain*> getPredicateDomains(
}
}

// NOTE: we don't need to predicate on the ScatterGather ID. We should probably remove it.
if (expr->isA<ScatterOp>()) {
predicate_domains.erase(
std::remove_if(
predicate_domains.begin(),
predicate_domains.end(),
[&](IterDomain* id) -> bool { return id == expr->as<ScatterOp>()->getConsumerLogicalID(); }),
predicate_domains.end());
}

return predicate_domains;
}

Expand Down
15 changes: 11 additions & 4 deletions csrc/index_compute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1739,13 +1739,18 @@ std::vector<Val*> Index::getStrides(TensorView* tv) {
std::vector<Val*> Index::getConsumerAllocationIndices(
const TensorView* tv,
const std::vector<ForLoop*>& loops,
const IndexFromIdGraph& index_from_id_graph) {
const IndexFromIdGraph& index_from_id_graph,
const std::unordered_map<int, Val*>& override_index) {
const auto& alloc_dom = tv->getMaybeAllocationDomain();
auto indexing = index_from_id_graph.index;

std::vector<Val*> alloc_inds(
alloc_dom.size(), GpuLower::current()->kernel()->zeroVal());
for (const auto i : arange(alloc_dom.size())) {
for (const auto i : arange((int)alloc_dom.size())) {
if (override_index.count(i)) {
alloc_inds[i] = nullptr;
continue;
}
// See a comment in indexing to allocation domains in
// getGlobalProducerIndex.
if (alloc_dom[i]->isReduction() || alloc_dom[i]->isBroadcast() ||
Expand Down Expand Up @@ -1898,7 +1903,7 @@ std::vector<Val*> Index::getGlobalConsumerStridedIndices(
// if we need to override index, we need to generate the index from each
// allocation axis firstly.
auto alloc_inds =
getConsumerAllocationIndices(consumer_tv, loops, index_from_id_graph);
getConsumerAllocationIndices(consumer_tv, loops, index_from_id_graph, override_index);

// Global striding
auto vectorize_shift =
Expand All @@ -1908,8 +1913,10 @@ std::vector<Val*> Index::getGlobalConsumerStridedIndices(
for (const auto i : arange(alloc_inds.size())) {
auto override_it = override_index.find((int)i);
if (override_it != override_index.end()) {
NVF_ERROR(alloc_inds[i] == nullptr);
alloc_inds[i] = override_it->second;
}
NVF_ERROR(alloc_inds[i] != nullptr);
if (alloc_inds[i]->isZeroInt()) {
continue;
} else {
Expand Down Expand Up @@ -2326,7 +2333,7 @@ kir::TensorIndex* Index::getConsumerIndex(
bool generate_pointer,
DataType as_type) {
Val* index = nullptr;
if (!ir_utils::hasRootToLoopLinearTransformations(consumer) ||
if (!ir_utils::hasRootToLoopLinearTransformations(consumer, override_index) ||
ir_utils::isCpAsyncBulkLoad(consumer->definition()) ||
GpuLower::current()->idModelOptions().consumerIndex() ||
GpuLower::current()->tmemInfo().hasTMemTensor()) {
Expand Down
3 changes: 2 additions & 1 deletion csrc/index_compute.h
Original file line number Diff line number Diff line change
Expand Up @@ -450,7 +450,8 @@ class Index {
static std::vector<Val*> getConsumerAllocationIndices(
const TensorView* tv,
const std::vector<ForLoop*>& loops,
const IndexFromIdGraph& index_from_id_graph);
const IndexFromIdGraph& index_from_id_graph,
const std::unordered_map<int, Val*>& override_index = {});

// get the allocation indices of a producer tensor
static std::vector<Val*> getProducerAllocationIndices(
Expand Down
11 changes: 10 additions & 1 deletion csrc/ir/internal_base_nodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -443,7 +443,8 @@ class TensorDomain : public Val {
std::vector<IterDomain*> allocation,
std::vector<IterDomain*> loop_domain,
std::vector<std::optional<bool>> contiguity = {},
std::vector<IterDomain*> additional_ids = {});
std::vector<IterDomain*> additional_ids = {},
std::vector<IterDomain*> no_loop_ids = {});

TensorDomain(IrBuilderPasskey, const TensorDomain* src);

Expand Down Expand Up @@ -625,6 +626,13 @@ class TensorDomain : public Val {
return additional_ids_;
}

// Logical IDs that are intentionally not covered by the loop domain (e.g. a
// scatter output axis that has no corresponding loop). Loop-domain validation
// treats them as if they were part of the loop domain when checking coverage.
// NOTE(review): semantics inferred from validateLoopDomain's use of
// no_reference_ids -- confirm.
const std::vector<IterDomain*>& noLoopIDs() const {
return no_loop_ids_;
}

// Set the loop domain of this TensorDomain.
void setLoopDomain(std::vector<IterDomain*> new_loop_domain);

Expand Down Expand Up @@ -739,6 +747,7 @@ class TensorDomain : public Val {
// setLoopDomain
std::vector<IterDomain*> initial_loop_domain_;
std::vector<IterDomain*> additional_ids_;
std::vector<IterDomain*> no_loop_ids_;

std::vector<IterDomain*> no_bcast_domain_;
std::vector<IterDomain*> no_reduction_domain_;
Expand Down
4 changes: 4 additions & 0 deletions csrc/ir/internal_nodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,10 @@ class ScatterOp : public Expr {

IterDomain* getIndexedID() const;

IterDomain* getConsumerLoopID() const;

IterDomain* getConsumerLogicalID() const;

ScatterOpType getScatterOpType() const {
return attribute<ScatterOpType>(1);
}
Expand Down
47 changes: 35 additions & 12 deletions csrc/ir/nodes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,15 @@ std::string ScatterOp::toInlineString(int indent_size) const {
}

IterDomain* ScatterOp::getIndexedID() const {
return ir_utils::getTvOutput(this)->getLogicalDomain().at(dim());
return ir_utils::getTvOutput(this)->getRootDomain().at(dim());
}

// Returns the output's root-domain ID at position dim() + 1.
// NOTE(review): this reads one past the scatter dimension in the root
// domain -- presumably the extra ID inserted for the scatter loop in this
// change; confirm against how ScatterOp constructs its output root domain
// (an out-of-range dim() would throw via .at()).
IterDomain* ScatterOp::getConsumerLoopID() const {
return ir_utils::getTvOutput(this)->getRootDomain().at(dim() + 1);
}

// Returns the output's root-domain ID at the scatter dimension (dim()).
// NOTE(review): despite the "Logical" name this reads getRootDomain(), not
// getLogicalDomain(), and it returns the same ID as the updated
// getIndexedID() -- confirm whether the duplication/naming is intentional
// (scatter outputs appear to carry the indexed axis in the root domain in
// this change).
IterDomain* ScatterOp::getConsumerLogicalID() const {
return ir_utils::getTvOutput(this)->getRootDomain().at(dim());
}

std::vector<PolymorphicValue> ScatterOp::evaluate(
Expand Down Expand Up @@ -3134,7 +3142,8 @@ void validateContiguity(
void validateLoopDomain(
const std::vector<IterDomain*>& logical_domain,
const std::vector<IterDomain*>& loop_domain,
const std::vector<IterDomain*>& additional_ids) {
const std::vector<IterDomain*>& additional_ids,
const std::vector<IterDomain*>& no_reference_ids) {
// Skip if there's any symbolic ID
if (std::any_of(
logical_domain.begin(),
Expand All @@ -3159,8 +3168,16 @@ void validateLoopDomain(
reference.insert(
reference.end(), additional_ids.begin(), additional_ids.end());

std::vector<IterDomain*> covered_loop_domain;
covered_loop_domain.reserve(loop_domain.size() + no_reference_ids.size());
covered_loop_domain.insert(
covered_loop_domain.end(), loop_domain.begin(), loop_domain.end());
// no_reference_ids are also considered part of the loop domain
covered_loop_domain.insert(
covered_loop_domain.end(), no_reference_ids.begin(), no_reference_ids.end());

auto [redundant_ids, _, unreachable_reference_ids] =
ir_utils::compareDomainWithReference(loop_domain, reference);
ir_utils::compareDomainWithReference(covered_loop_domain, reference);

auto empty_or_broadcast = [](const auto& ids) {
return std::all_of(ids.begin(), ids.end(), [](IterDomain* id) {
Expand All @@ -3177,6 +3194,8 @@ void validateLoopDomain(
empty_or_broadcast(unreachable_reference_ids),
"Not all logical IDs are covered by loop domain. Loop: ",
toDelimitedString(loop_domain),
". no loop logical IDs: ",
toDelimitedString(no_reference_ids),
". Unreachable logical IDs: ",
toDelimitedString(unreachable_reference_ids));
}
Expand Down Expand Up @@ -3251,7 +3270,7 @@ TensorDomain::TensorDomain(
NVF_CHECK(
loop_domain_.empty() == logical_domain_.empty(),
"logical domain and loop domain can only be both empty or neither empty");
validateLoopDomain(logical_domain_, loop_domain_, additional_ids_);
validateLoopDomain(logical_domain_, loop_domain_, additional_ids_, no_loop_ids_);

// resetDomains initializes other member variables, required by clang-tidy
resetDomains();
Expand All @@ -3275,7 +3294,7 @@ TensorDomain::TensorDomain(
NVF_CHECK(
loop_domain_.empty() == logical_domain_.empty(),
"logical domain and loop domain can only be both empty or neither empty");
validateLoopDomain(logical_domain_, loop_domain_, additional_ids_);
validateLoopDomain(logical_domain_, loop_domain_, additional_ids_, no_loop_ids_);
if (!root_domain_.empty()) {
ir_utils::validateDomainEquivalence(
logical_domain_, root_domain_, additional_ids_);
Expand All @@ -3292,14 +3311,16 @@ TensorDomain::TensorDomain(
std::vector<IterDomain*> allocation_domain,
std::vector<IterDomain*> loop_domain,
std::vector<std::optional<bool>> contiguity,
std::vector<IterDomain*> additional_ids)
std::vector<IterDomain*> additional_ids,
std::vector<IterDomain*> no_loop_ids)
: Val(passkey, ValType::TensorDomain, DataType::Null),
root_domain_(std::move(root_domain)),
logical_domain_(std::move(logical_domain)),
allocation_domain_(std::move(allocation_domain)),
loop_domain_(std::move(loop_domain)),
initial_loop_domain_(loop_domain_),
additional_ids_(std::move(additional_ids)),
no_loop_ids_(std::move(no_loop_ids)),
contiguity_(
contiguity.empty() ? getContiguityFilledWith(maybeAllocation(), false)
: std::move(contiguity)) {
Expand All @@ -3308,11 +3329,11 @@ TensorDomain::TensorDomain(
NVF_CHECK(
loop_domain_.empty() == logical_domain_.empty(),
"logical domain and loop domain can only be both empty or neither empty");
validateLoopDomain(logical_domain_, loop_domain_, additional_ids_);
if (!root_domain_.empty()) {
ir_utils::validateDomainEquivalence(
logical_domain_, root_domain_, additional_ids_);
}
validateLoopDomain(logical_domain_, loop_domain_, additional_ids_, no_loop_ids_);
// if (!root_domain_.empty()) {
// ir_utils::validateDomainEquivalence(
// logical_domain_, root_domain_, additional_ids_);
// }
if (!allocation_domain_.empty()) {
ir_utils::validateDomainEquivalence(
logical_domain_, allocation_domain_, additional_ids_);
Expand All @@ -3330,6 +3351,7 @@ TensorDomain::TensorDomain(IrBuilderPasskey passkey, const TensorDomain* src)
loop_domain_(src->loop_domain_),
initial_loop_domain_(src->initial_loop_domain_),
additional_ids_(src->additional_ids_),
no_loop_ids_(src->no_loop_ids_),
no_bcast_domain_(src->no_bcast_domain_),
no_reduction_domain_(src->no_reduction_domain_),
contiguity_(src->contiguity_),
Expand All @@ -3343,6 +3365,7 @@ TensorDomain::TensorDomain(const TensorDomain* src, IrCloner* ir_cloner)
loop_domain_(ir_cloner->clone(src->loop_domain_)),
initial_loop_domain_(ir_cloner->clone(src->initial_loop_domain_)),
additional_ids_(ir_cloner->clone(src->additional_ids_)),
no_loop_ids_(ir_cloner->clone(src->no_loop_ids_)),
no_bcast_domain_(ir_cloner->clone(src->no_bcast_domain_)),
no_reduction_domain_(ir_cloner->clone(src->no_reduction_domain_)),
contiguity_(src->contiguity()),
Expand Down Expand Up @@ -3902,7 +3925,7 @@ std::pair<TensorDomain*, TensorDomain*> TensorDomain::rFactor(
}

void TensorDomain::setLoopDomain(std::vector<IterDomain*> new_loop_domain) {
validateLoopDomain(logical(), new_loop_domain, additionalIDs());
validateLoopDomain(logical(), new_loop_domain, additionalIDs(), noLoopIDs());
loop_domain_ = std::move(new_loop_domain);
initial_loop_domain_ = loop_domain_;
resetDomains();
Expand Down
40 changes: 34 additions & 6 deletions csrc/ir/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -781,6 +781,29 @@ bool isIndexSelectLookupTv(const TensorView* tv) {
return false;
}

bool isScatterSelfTv(const TensorView* tv) {
for (auto expr : tv->uses()) {
if (expr->isA<ScatterOp>()) {
auto idx_sel = expr->as<ScatterOp>();
if (idx_sel->selfTv() == tv) {
return true;
}
}
}
return false;
}
bool isScatterIndexTv(const TensorView* tv) {
for (auto expr : tv->uses()) {
if (expr->isA<ScatterOp>()) {
auto idx_sel = expr->as<ScatterOp>();
if (idx_sel->indexTv() == tv) {
return true;
}
}
}
return false;
}

bool isIndexSelectIndicesTv(const TensorView* tv) {
for (auto expr : tv->uses()) {
if (expr->isA<IndexSelectOp>()) {
Expand Down Expand Up @@ -1343,7 +1366,7 @@ bool hasUniformSiblings(Expr* expr) {
return !expr->isOneOf<SdpaFwdOp, SdpaBwdOp>();
}

bool hasRootToLoopLinearTransformations(const TensorView* tv) {
bool hasRootToLoopLinearTransformations(const TensorView* tv, const std::unordered_map<int, Val*>& override_index) {
auto root = tv->getMaybeRootDomain();
auto loop = tv->getLoopDomain();
std::vector<Val*> loop_val(loop.begin(), loop.end());
Expand All @@ -1352,11 +1375,16 @@ bool hasRootToLoopLinearTransformations(const TensorView* tv) {
std::unordered_set<Val*> all_ids_set(all_ids_vec.begin(), all_ids_vec.end());
auto alloc = tv->getMaybeAllocationDomain();
auto logical = tv->getLogicalDomain();
bool all_alloc_id_on_path = std::all_of(
alloc.begin(), alloc.end(), [&](Val* v) { return all_ids_set.count(v); });
bool all_logical_id_on_path =
std::all_of(logical.begin(), logical.end(), [&](Val* v) {
return all_ids_set.count(v);

bool all_alloc_id_on_path = std::ranges::all_of(
nvfuser::views::enumerate_view(alloc), [&](const auto& enumerator) {
const auto& [index, v] = enumerator;
return override_index.count(index) || all_ids_set.count(v);
});
bool all_logical_id_on_path = std::ranges::all_of(
nvfuser::views::enumerate_view(logical), [&](const auto& enumerator) {
const auto& [index, v] = enumerator;
return override_index.count(index) || all_ids_set.count(v);
});
return all_alloc_id_on_path && all_logical_id_on_path;
}
Expand Down
5 changes: 4 additions & 1 deletion csrc/ir/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -440,6 +440,9 @@ bool isIndexSelectLookupTv(const TensorView* tv);
// Check if the given tv is third argment of indexSelect(lookup, dim, indices)
bool isIndexSelectIndicesTv(const TensorView* tv);

bool isScatterSelfTv(const TensorView* tv);
bool isScatterIndexTv(const TensorView* tv);

bool isGatherLookupTv(const Val* tv);

std::string varName(const Val* val);
Expand Down Expand Up @@ -749,7 +752,7 @@ inline bool isMemorySharedAcross(
//! Check if the given tv has a root domain -> loop domain linear
//! transformation. This is a temporary check used to incrementally enable
//! IdModel. Eventually, this should be removed.
bool hasRootToLoopLinearTransformations(const TensorView* tv);
bool hasRootToLoopLinearTransformations(const TensorView* tv, const std::unordered_map<int, Val*>& override_index = {});

//! In addition to the above hasRootToLoopLinearTransformations, it
//! also checks the loop domain has any extra domain
Expand Down
Loading
Loading