Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
43 commits
Select commit Hold shift + click to select a range
3bd640d
Allow definition of PrefixSumOp
jacobhinkle Apr 7, 2025
cde4b31
Update test to compile and run
jacobhinkle Apr 7, 2025
22ea549
Start implementing indexing for PrefixSumOp
jacobhinkle Apr 7, 2025
61291be
Merge remote-tracking branch 'origin/main' into jh/scan_wip
jacobhinkle Apr 8, 2025
eefe3d8
Enable TensorIndexer in test
jacobhinkle Apr 8, 2025
e8ae971
Update indexing
jacobhinkle Apr 8, 2025
1caf5eb
Add note about inlining in test
jacobhinkle Apr 8, 2025
fb24bfd
Fix evaluation of scalar discount factor
jacobhinkle Apr 8, 2025
bd5fd7b
Merge remote-tracking branch 'origin/main' into jh/scan_wip
jacobhinkle Apr 9, 2025
978d082
Add Concrete2D test
jacobhinkle Apr 9, 2025
bb44a1a
Fix discount factor
jacobhinkle Apr 9, 2025
6e7e001
Add broken online softmax test, with notes in comments
jacobhinkle Apr 9, 2025
f581d16
Generalize to ScanOp
jacobhinkle Apr 10, 2025
e376e5e
Generalize evaluation, some cleanups
jacobhinkle Apr 10, 2025
2b3eff4
Add LHS and RHS, use to improve OnlineSoftmax def
jacobhinkle Apr 10, 2025
3b90e99
Remove stale comment
jacobhinkle Apr 10, 2025
7cfa4b1
Add exclusive
jacobhinkle Apr 10, 2025
756508d
Evaluate exclusive scan
jacobhinkle Apr 11, 2025
d2c7bd6
Merge remote-tracking branch 'origin/main' into jh/scan_wip
jacobhinkle Apr 11, 2025
19a1e50
Fix evaluation. OnlineSoftmax test passes but is not inlined
jacobhinkle Apr 11, 2025
be81b65
Use computeWith to "inline" OnlineSoftmax into one loop
jacobhinkle Apr 11, 2025
01fa5b2
Cache inputs and outputs in OnlineSoftmax test
jacobhinkle Apr 11, 2025
bb81c4a
Add OnlineSoftmaxOuter test
jacobhinkle Apr 11, 2025
9cf09ec
Introduce IterType::Scan
jacobhinkle Apr 11, 2025
6de0839
Skip allocating scan dimensions
jacobhinkle Apr 11, 2025
4834913
Merge remote-tracking branch 'origin/main' into jh/scan_wip
jacobhinkle Apr 14, 2025
11ceeff
Remove stale comments from tests
jacobhinkle Apr 14, 2025
21fd7c1
First draft of validation
jacobhinkle Apr 14, 2025
f5574c2
Merge remote-tracking branch 'origin/main' into jh/scan_wip
jacobhinkle Apr 15, 2025
d5b95c7
Set iter type in non-exclusive scan
jacobhinkle Apr 16, 2025
cc845dc
Allocate scan dims when needed
jacobhinkle Apr 18, 2025
de35509
Add ComputeAtLogicalDomainMapBuilder::handle(ScanOp*)
jacobhinkle Apr 18, 2025
4369a16
Fix inlining in OnlineSoftmax tests
jacobhinkle Apr 18, 2025
db5798c
Use previously computed exclusive scan to compute inclusive
jacobhinkle Apr 18, 2025
aed6eb2
Switch OnlineSoftmax{,Outer} to symbolic sizes
jacobhinkle Apr 18, 2025
8e846b4
Fix outer test
jacobhinkle Apr 18, 2025
6c48103
Compute full softmax in test
jacobhinkle Apr 18, 2025
14e52bb
Introduce ScanResult
jacobhinkle Apr 21, 2025
d145b82
Enable reduction output, but not in indexing yet.
jacobhinkle Apr 21, 2025
4c655ec
Fix build
jacobhinkle Apr 21, 2025
1a01914
Add FlashAttentionNoMma test
jacobhinkle Apr 22, 2025
deaff1f
More fixes to FlashAttentionNoMma
jacobhinkle Apr 22, 2025
302015c
Merge remote-tracking branch 'origin/main' into jh/scan_wip
jacobhinkle Jul 3, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -706,6 +706,7 @@ list(APPEND JIT_TEST_SRCS
${NVFUSER_ROOT}/tests/cpp/test_rope.cpp
${NVFUSER_ROOT}/tests/cpp/test_runtime.cpp
${NVFUSER_ROOT}/tests/cpp/test_scalar_hoisting.cpp
${NVFUSER_ROOT}/tests/cpp/test_scan.cpp
${NVFUSER_ROOT}/tests/cpp/test_scatter.cpp
${NVFUSER_ROOT}/tests/cpp/test_sdpa_node.cpp
${NVFUSER_ROOT}/tests/cpp/test_segmentation.cpp
Expand Down
3 changes: 3 additions & 0 deletions csrc/device_lower/lower2device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -563,6 +563,9 @@ void GpuLower::analysis(Fusion* fusion) {
validateReductions(fusion_);
dumpExprsIfEnabled(fusion_->exprs(), "validateReductions");

validateScans(fusion_);
dumpExprsIfEnabled(fusion_->exprs(), "validateScans");

// Compute thread predicates. Depends on parallel_dimension_map_
thread_pred_map_.build(fusion_);
dumpExprsIfEnabled(fusion_->exprs(), "build thread_pred_map_");
Expand Down
18 changes: 16 additions & 2 deletions csrc/device_lower/pass/allocation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,18 @@ bool mayRequireAllocation(const TensorView* tv, IterDomain* id) {
// remain size one.
// - Reduction: Check the original ID, not the promotion, which may
// be a reduction ID even though the original ID is not a reduction

if (id->isScan()) {
// Allocate IterType::Scan IDs only if the tensor is a fusion output or does
// not use computeWith. We know tv must not have computeAt past the scan id,
// so without computeWith, we know the expression won't be inlined and we'll
// need to allocate the scan id.
if (tv->isFusionOutput()) {
return true;
}
return !tv->hasComputeWith();
}

return !isPartitionedLoop(tv, id) && !isSizeOneDomain(id) &&
!id->isReduction() && !id->isStride();
}
Expand Down Expand Up @@ -1082,7 +1094,8 @@ class AllocationInserter : public kir::ExprMutator {
std::vector<IterDomain*> init_dims;
for (const auto axis_i : arange(info.alloc_pos, info.buffer->nDims())) {
if (info.buffer->axis(axis_i)->isReduction() ||
info.buffer->axis(axis_i)->isBroadcast()) {
info.buffer->axis(axis_i)->isBroadcast() ||
info.buffer->axis(axis_i)->isScan()) {
continue;
}
auto concrete_id =
Expand Down Expand Up @@ -1125,7 +1138,8 @@ class AllocationInserter : public kir::ExprMutator {
info.allocation_domains =
std::make_unique<std::vector<IterDomain*>>(alloc_ids);

if (alloc_dims.empty() && !info.buffer->domain()->noReductions().empty()) {
if (alloc_dims.empty() &&
!TensorDomain::noScans(info.buffer->domain()->noReductions()).empty()) {
alloc_dims.push_back(info.buffer->container()->oneVal());
}

Expand Down
82 changes: 82 additions & 0 deletions csrc/device_lower/pass/index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2812,4 +2812,86 @@ void IndexLowering::handle(const CatOp* cat) {
GpuLower::current()->propagateExprInfo(cat, expr);
}

// Lowers a ScanOp into a serial scan expression of the form
//   out[i] = op(where(i > start, out[i-1], init), in[i])
// i.e. each iteration of the scanned loop reads back the value the scan
// stored on the previous iteration and combines it with the current input.
void IndexLowering::handle(const ScanOp* scop) {
  const auto in = lowerSrcIndex(scop->in(), scop->out());
  const auto out = lowerDstIndex(scop->out());

  // Find index for the scanDim IterDomain loop, so that we can modify the index
  // by subtracting one.
  IterDomain* scan_id =
      TensorDomain::noReductions(scop->in()->getLogicalDomain())
          .at((size_t)scop->scanDim());
  // Locate the scanned ID's position in the allocation domain; this position
  // is what lowerDstIndex's override_index is keyed on below.
  int scan_id_alloc_pos = -1;
  const std::vector<IterDomain*>& alloc_dom =
      scop->in()->getMaybeAllocationDomain();
  for (size_t alloc_pos : arange(alloc_dom.size())) {
    if (alloc_dom.at(alloc_pos) == scan_id) {
      scan_id_alloc_pos = alloc_pos;
      break;
    }
  }
  NVF_ERROR(
      scan_id_alloc_pos != -1,
      "Could not find scanned ID in allocation domain. ",
      "Scan dimensions must not be merged or split during scheduling");
  // Find the ForLoop corresponding to the scanned ID. The loop's IterDomain
  // is generally a different Val than scan_id, so we compare their ValGroups
  // in the indexing traversal graph instead of comparing pointers.
  // NOTE(review): assumes at most one loop in for_loops_ maps to the scan
  // group -- confirm for more complex schedules.
  ValGraph& id_graph = GpuLower::current()->tensorIndexer().traversalGraph();
  const ValGroup& scan_group = id_graph.toGroup(scan_id);
  ForLoop* scan_loop = nullptr;
  for (ForLoop* loop : for_loops_) {
    if (id_graph.toGroup(loop->iter_domain()) == scan_group) {
      scan_loop = loop;
      break;
    }
  }
  NVF_ERROR(
      scan_loop != nullptr,
      "Could not find for loop with scanned ID. ",
      "Scan dimensions must not be merged or split during scheduling");
  Val* scan_index = GpuLower::current()->tensorIndexer().getLoopIndex(
      scan_loop->iter_domain(), for_loops_);
  // Index of the previous iteration along the scanned loop (i - 1)
  Val* lagged_index = sub(scan_index, GpuLower::current()->kernel()->oneVal());
  // This gives us the previously computed value along the scanned dimension
  Val* prev_sum_tensor = lowerDstIndex(
      scop->out(), /*override_index=*/{{scan_id_alloc_pos, lagged_index}});

  // Cache the tensors to scalars, to make building the where expression simpler
  Val* next_val = IrBuilder::create<Val>(scop->in()->dtype());
  IrBuilder::create<LoadStoreOp>(LoadStoreOpType::Set, next_val, in);

  // Convert prev_sum to a scalar Val so that we don't need to allocate a new
  // TensorView for it.
  Val* prev_sum = IrBuilder::create<Val>(scop->out()->dtype());
  IrBuilder::create<LoadStoreOp>(
      LoadStoreOpType::Set, prev_sum, prev_sum_tensor);

  // On the first iteration (scan_index == loop start) there is no previous
  // value, so substitute the scan's init value instead.
  prev_sum = where(gt(scan_index, scan_loop->start()), prev_sum, scop->init());

  // If an exclusive output was requested, store the running value *before*
  // combining it with the current element, then continue the computation
  // from that stored value.
  if (TensorView* exc = scop->outExclusive()) {
    const auto exc_ti = lowerDstIndex(exc);
    auto* save_exc_op =
        IrBuilder::create<LoadStoreOp>(LoadStoreOpType::Set, exc_ti, prev_sum);
    pushBack(save_exc_op);
    prev_sum = exc_ti;
  }

  // Apply the optional discount factor: the previous running value is
  // multiplied by f before being combined with the next element. f may be a
  // TensorView (indexed like the output) or a scalar.
  if (Val* f = scop->discountFactor()) {
    Val* f_scalar = f;
    if (auto* f_tv = dynamic_cast<TensorView*>(f)) {
      Val* f_ti = lowerSrcIndex(f_tv, scop->out());
      Val* prev_sum_mul = IrBuilder::create<Val>(f_tv->dtype());
      IrBuilder::create<BinaryOp>(
          BinaryOpType::Mul, prev_sum_mul, f_ti, prev_sum);
      prev_sum = prev_sum_mul;
    } else {
      prev_sum = mul(f_scalar, prev_sum);
    }
  }

  // Combine the (possibly discounted) previous value with the current input
  // using the scan's binary op, writing the inclusive scan output.
  Expr* expr =
      IrBuilder::create<BinaryOp>(scop->opType(), out, prev_sum, next_val);

  pushBack(expr);
  GpuLower::current()->propagateExprInfo(scop, expr);
}

} // namespace nvfuser
1 change: 1 addition & 0 deletions csrc/device_lower/pass/index.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ class IndexLowering : private OptOutConstDispatch {
void handle(const PadOp*) final;
void handle(const SliceOp*) final;
void handle(const CatOp*) final;
void handle(const ScanOp*) final;

void handle(const kir::Asm*) final;
void handle(const ForLoop*) final;
Expand Down
1 change: 1 addition & 0 deletions csrc/device_lower/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ bool isTvOp(const Expr* expr) {
SdpaBwdOp,
EmbeddingFwdOp,
BroadcastOp,
ScanOp,
SqueezeOp,
ExpandOp,
RepeatOp,
Expand Down
72 changes: 72 additions & 0 deletions csrc/device_lower/validation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1319,4 +1319,76 @@ void validate1dTmaLoad(Fusion* fusion) {
}
}

// When we compute a scan we compute something like the following
//
//   Array<float, 1, 1> T3;
//   #pragma unroll
//   for(nvfuser_index_t i1 = 0; i1 < 32; ++i1) {
//     T3[0]
//       = ((i1 > 0) ? (T3[0]) : 0.000000000e+00f)
//       + (T2[i1]);
//   }
//
// We need to validate that T3 is not inlined past the scan axis (the i1 loop
// in this case), because its lifetime must persist beyond the scan loop. Note
// that it is permissible to use `computeWith` as in this example to move the
// computed position inside the scan loop, alleviating the need to allocate an
// axis of size 32 in this case.
//
// Validate:
//  1. Outputs are inlined with all their uses past the scan dim
//  2. Discount factor and input are computed with this expression past the
//     scan dim
//  3. Outputs are not inlined past the scan dimension, since we require the
//     scan outputs to be allocated outside the scan loop
void validateScans(Fusion* fusion) {
  for (auto sop : ir_utils::getOpsOfType<ScanOp>(fusion)) {
    auto* out = sop->out()->as<TensorView>();
    TensorView* out_exclusive = sop->outExclusive();

    // Find position of scan dim in loop domain
    IterDomain* scan_id = out->getLogicalDomain().at((size_t)sop->scanDim());
    int64_t scan_pos = -1L;
    for (int64_t pos : arange(out->nDims())) {
      if (out->axis(pos) == scan_id) {
        scan_pos = pos;
        break;
      }
    }
    NVF_ERROR(
        scan_pos != -1L,
        "Could not find scan dimension ",
        scan_id->toString(),
        " in loop domain. Scan dimensions must not be scheduled with splits or "
        "merges");

    // Every consumer of `output` must be inlined (via computeWith) to or past
    // the scan position so that the consuming expressions run inside the scan
    // loop.
    const auto check_uses = [&](TensorView* output) {
      for (Expr* use : output->uses()) {
        for (Val* outp : use->outputs()) {
          if (auto* out_tv = dynamic_cast<TensorView*>(outp)) {
            NVF_ERROR(
                out_tv->getComputeWithPosition() >= scan_pos,
                "Use of output, ",
                use->toString(),
                " must have all outputs inlined or computeWith to or past scan "
                "position ",
                scan_pos);
          }
        }
      }
    };

    check_uses(out);

    if (out_exclusive != nullptr) {
      // The exclusive output must be computed inside the scan loop
      // (computeWith past the scan position) but allocated outside of it
      // (computeAt not past the scan position).
      NVF_ERROR(
          out_exclusive->getComputeWithPosition() >= scan_pos,
          "Exclusive scan output ",
          out_exclusive->toString(),
          " must have computeWith position at or past scan position ",
          scan_pos);
      NVF_ERROR(
          out_exclusive->getComputeAtPosition() <= scan_pos,
          "Exclusive scan output ",
          out_exclusive->toString(),
          " must not be inlined past scan position ",
          scan_pos);
      check_uses(out_exclusive);
    }
    // Must have allocated outside scan loop
    NVF_ERROR(
        out->getComputeAtPosition() <= scan_pos,
        "Scan output ",
        out->toString(),
        " must not be inlined past scan position ",
        scan_pos);
  }
}

} // namespace nvfuser
4 changes: 4 additions & 0 deletions csrc/device_lower/validation.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,4 +75,8 @@ void validateReductions(Fusion* fusion);
//! divisible. This is similar to vectorization, where we don't have an extra
//! else branch to load the tailing elements.
void validate1dTmaLoad(Fusion* fusion);

//! Validate scheduling of ScanOp inputs and outputs
void validateScans(Fusion* fusion);

} // namespace nvfuser
2 changes: 1 addition & 1 deletion csrc/dispatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -102,12 +102,12 @@ class Val;
f(ViewOp); \
f(CatOp); \
f(PadOp); \
f(ScanOp); \
f(SliceOp); \
f(Split); \
f(ArgsortOp); \
f(GroupedMmaOp); \
f(TopKOp); \
f(ScanOp); \
f(Merge); \
f(Swizzle); \
f(Swizzle2D); \
Expand Down
1 change: 1 addition & 0 deletions csrc/expr_evaluator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,7 @@ const PolymorphicValue& ExpressionEvaluator::evaluate(
if (!maybe_concrete_value.get().hasValue()) {
if (auto def = value->definition()) {
auto outputs = def->evaluate(*this, known_values);
NVF_ERROR(outputs.size() == def->outputs().size());
for (auto i : arange(def->outputs().size())) {
known_values[def->output(i)] = std::move(outputs[i]);
}
Expand Down
2 changes: 1 addition & 1 deletion csrc/id_model/indexing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -971,7 +971,7 @@ std::pair<std::vector<Val*>, std::vector<Val*>> TensorIndexer::
auto index_info = computeIndex(
expr, indexed_ids, for_loops, isSharedMemoryTvForLdStMatrix(tv, expr));
for (const auto& [indexed_id, index] : override_index) {
index_info.index_map.emplace(traversalGraph().toGroup(indexed_id), index);
index_info.index_map[traversalGraph().toGroup(indexed_id)] = index;
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe this is a bug since emplace will simply fail and return false if there is a pre-existing entry, whereas we need to update the index here instead.

}
const auto& index_map = index_info.index_map;
auto replacement_map = getIndexReplacementMap(
Expand Down
7 changes: 6 additions & 1 deletion csrc/index_compute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2333,8 +2333,13 @@ kir::TensorIndex* Index::getConsumerIndex(
GpuLower::current()->idModelOptions().consumerIndex() ||
GpuLower::current()->tmemInfo().hasTMemTensor()) {
NVF_ERROR(rotated_loops.empty(), "Loop rotation is not supported");
std::unordered_map<IterDomain*, Val*> override_index_ids;
for (auto& [pos, idx] : override_index) {
override_index_ids.emplace(
consumer->getMaybeAllocationDomain().at((size_t)pos), idx);
}
index = GpuLower::current()->tensorIndexer().getLinearIndex(
consumer, consumer->definition(), loops);
consumer, consumer->definition(), loops, override_index_ids);
Comment on lines +2336 to +2342
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe this is necessary for getConsumerIndex to respect override_index.

if (generate_pointer) {
auto address_offset = index;
if (consumer->getMemoryType() == MemoryType::Shared) {
Expand Down
5 changes: 5 additions & 0 deletions csrc/ir/internal_base_nodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,10 @@ class IterDomain : public Val {
return getIterType() == IterType::Symbolic;
}

bool isScan() const {
return getIterType() == IterType::Scan;
}

bool isGatherScatter() const {
return getIterType() == IterType::GatherScatter;
}
Expand Down Expand Up @@ -724,6 +728,7 @@ class TensorDomain : public Val {
const std::unordered_map<int64_t, int64_t>& old2new);

static std::vector<IterDomain*> noReductions(const std::vector<IterDomain*>&);
static std::vector<IterDomain*> noScans(const std::vector<IterDomain*>&);
static std::vector<IterDomain*> noBroadcasts(const std::vector<IterDomain*>&);
static std::vector<IterDomain*> noDevices(const std::vector<IterDomain*>&);

Expand Down
53 changes: 43 additions & 10 deletions csrc/ir/internal_nodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -3167,12 +3167,17 @@ class ScanOp : public Expr {
public:
using Expr::Expr;

// NOTE: We translate these nodes to other nodes during indexing, so we should
// never expect to receive TensorIndex arguments here
ScanOp(
IrBuilderPasskey,
BinaryOpType op_type,
TensorView* output_inclusive,
TensorView* output_exclusive,
TensorView* output_reduction,
TensorView* input,
Val* discount_factor,
Val* init,
Val* out,
Val* in,
int64_t dim);

NVFUSER_DECLARE_CLONE_AND_CREATE
Expand All @@ -3185,24 +3190,52 @@ class ScanOp : public Expr {
std::string toInlineString(int indent_size = 0) const override;

//! Returns the inclusive scan output
Val* out() const {
return output(0);
//! Returns the inclusive scan output (always output 0)
TensorView* out() const {
  return output(0)->as<TensorView>();
}

Val* in() const {
return input(0);
//! Exclusive scan output, or nullptr when this op was built without one
TensorView* outExclusive() const {
  return hasExclusive() ? output(1)->as<TensorView>() : nullptr;
}

Val* init() const {
return attributeVal(0);
//! Returns the reduction output if available, otherwise nullptr.
//! (The original comment said "exclusive scan output" — a copy-paste error
//! from outExclusive.) When present, the reduction output follows the
//! exclusive output, so its position is 2 if an exclusive output exists and
//! 1 otherwise.
TensorView* outReduction() const {
  if (!hasReduction()) {
    return nullptr;
  }
  return output(hasExclusive() ? 2 : 1)->as<TensorView>();
}

//! Returns the tensor being scanned (the first input)
TensorView* in() const {
  return input(0)->as<TensorView>();
}

//! Returns the optional discount factor (the second input), or nullptr when
//! no discount factor was provided
Val* discountFactor() const {
  return inputs().size() > 1 ? input(1) : nullptr;
}

BinaryOpType opType() const {
return attribute<BinaryOpType>(1);
return attribute<BinaryOpType>(0);
}

int64_t scanDim() const {
return attribute<int64_t>(2);
return attribute<int64_t>(1);
}

//! Whether this op also produces an exclusive scan output
bool hasExclusive() const {
  return attribute<bool>(2);
}

//! Whether this op also produces a reduction output
bool hasReduction() const {
  return attribute<bool>(3);
}

//! Returns the scan's initial value, used in place of the (nonexistent)
//! previous element on the first iteration of the scanned loop
Val* init() const {
  return attributeVal(4);
}

std::vector<PolymorphicValue> evaluate(
Expand Down
Loading