diff --git a/CMakeLists.txt b/CMakeLists.txt index 184f25afbac..48cab97f3f6 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -706,6 +706,7 @@ list(APPEND JIT_TEST_SRCS ${NVFUSER_ROOT}/tests/cpp/test_rope.cpp ${NVFUSER_ROOT}/tests/cpp/test_runtime.cpp ${NVFUSER_ROOT}/tests/cpp/test_scalar_hoisting.cpp + ${NVFUSER_ROOT}/tests/cpp/test_scan.cpp ${NVFUSER_ROOT}/tests/cpp/test_scatter.cpp ${NVFUSER_ROOT}/tests/cpp/test_sdpa_node.cpp ${NVFUSER_ROOT}/tests/cpp/test_segmentation.cpp diff --git a/csrc/device_lower/lower2device.cpp b/csrc/device_lower/lower2device.cpp index 8cfc0e416ad..94fe2a2795a 100644 --- a/csrc/device_lower/lower2device.cpp +++ b/csrc/device_lower/lower2device.cpp @@ -563,6 +563,9 @@ void GpuLower::analysis(Fusion* fusion) { validateReductions(fusion_); dumpExprsIfEnabled(fusion_->exprs(), "validateReductions"); + validateScans(fusion_); + dumpExprsIfEnabled(fusion_->exprs(), "validateScans"); + // Compute thread predicates. Depends on parallel_dimension_map_ thread_pred_map_.build(fusion_); dumpExprsIfEnabled(fusion_->exprs(), "build thread_pred_map_"); diff --git a/csrc/device_lower/pass/allocation.cpp b/csrc/device_lower/pass/allocation.cpp index f65824db83d..1a0cc61c77d 100644 --- a/csrc/device_lower/pass/allocation.cpp +++ b/csrc/device_lower/pass/allocation.cpp @@ -50,6 +50,18 @@ bool mayRequireAllocation(const TensorView* tv, IterDomain* id) { // remain size one. // - Reduction: Check the original ID, not the promotion, which may // be a reduction ID even though the original ID is not a reduction + + if (id->isScan()) { + // Allocate IterType::Scan IDs only if they are outputs or not computeWith. + // We know tv must not have computeAt past the scan id, so without + // computeWith, we know the expression won't be inlined and we'll need to + // allocate the scan id. + if (tv->isFusionOutput()) { + return true; + } + return !tv->hasComputeWith(); + } + return !isPartitionedLoop(tv, id) && !isSizeOneDomain(id) && !id->isReduction() && !id->isStride(); } @@ -1082,7 +1094,8 @@ class AllocationInserter : public kir::ExprMutator { std::vector init_dims; for (const auto axis_i : arange(info.alloc_pos, info.buffer->nDims())) { if (info.buffer->axis(axis_i)->isReduction() || - info.buffer->axis(axis_i)->isBroadcast()) { + info.buffer->axis(axis_i)->isBroadcast() || + info.buffer->axis(axis_i)->isScan()) { continue; } auto concrete_id = @@ -1125,7 +1138,8 @@ class AllocationInserter : public kir::ExprMutator { info.allocation_domains = std::make_unique>(alloc_ids); - if (alloc_dims.empty() && !info.buffer->domain()->noReductions().empty()) { + if (alloc_dims.empty() && + !TensorDomain::noScans(info.buffer->domain()->noReductions()).empty()) { alloc_dims.push_back(info.buffer->container()->oneVal()); } diff --git a/csrc/device_lower/pass/index.cpp b/csrc/device_lower/pass/index.cpp index 8a9f2de5d51..d520f2395a3 100644 --- a/csrc/device_lower/pass/index.cpp +++ b/csrc/device_lower/pass/index.cpp @@ -2812,4 +2812,86 @@ void IndexLowering::handle(const CatOp* cat) { GpuLower::current()->propagateExprInfo(cat, expr); } +void IndexLowering::handle(const ScanOp* scop) { + const auto in = lowerSrcIndex(scop->in(), scop->out()); + const auto out = lowerDstIndex(scop->out()); + + // Find index for the scanDim IterDomain loop, so that we can modify the index + // by subtracting one. 
+ IterDomain* scan_id = + TensorDomain::noReductions(scop->in()->getLogicalDomain()) + .at((size_t)scop->scanDim()); + int scan_id_alloc_pos = -1; + const std::vector& alloc_dom = + scop->in()->getMaybeAllocationDomain(); + for (size_t alloc_pos : arange(alloc_dom.size())) { + if (alloc_dom.at(alloc_pos) == scan_id) { + scan_id_alloc_pos = alloc_pos; + break; + } + } + NVF_ERROR( + scan_id_alloc_pos != -1, + "Could not find scanned ID in allocation domain. ", + "Scan dimensions must not be merged or split during scheduling"); + ValGraph& id_graph = GpuLower::current()->tensorIndexer().traversalGraph(); + const ValGroup& scan_group = id_graph.toGroup(scan_id); + ForLoop* scan_loop = nullptr; + for (ForLoop* loop : for_loops_) { + if (id_graph.toGroup(loop->iter_domain()) == scan_group) { + scan_loop = loop; + break; + } + } + NVF_ERROR( + scan_loop != nullptr, + "Could not find for loop with scanned ID. ", + "Scan dimensions must not be merged or split during scheduling"); + Val* scan_index = GpuLower::current()->tensorIndexer().getLoopIndex( + scan_loop->iter_domain(), for_loops_); + Val* lagged_index = sub(scan_index, GpuLower::current()->kernel()->oneVal()); + // This gives us the previously computed value along the scanned dimension + Val* prev_sum_tensor = lowerDstIndex( + scop->out(), /*override_index=*/{{scan_id_alloc_pos, lagged_index}}); + + // Cache the tensors to scalars, to make building the where expression simpler + Val* next_val = IrBuilder::create(scop->in()->dtype()); + IrBuilder::create(LoadStoreOpType::Set, next_val, in); + + // Convert prev_sum to a scalar Val so that we don't need to allocate a new + // TensorView for it. + Val* prev_sum = IrBuilder::create(scop->out()->dtype()); + IrBuilder::create( + LoadStoreOpType::Set, prev_sum, prev_sum_tensor); + + prev_sum = where(gt(scan_index, scan_loop->start()), prev_sum, scop->init()); + + if (TensorView* exc = scop->outExclusive()) { + const auto exc_ti = lowerDstIndex(exc); + auto* save_exc_op = + IrBuilder::create(LoadStoreOpType::Set, exc_ti, prev_sum); + pushBack(save_exc_op); + prev_sum = exc_ti; + } + + if (Val* f = scop->discountFactor()) { + Val* f_scalar = f; + if (auto* f_tv = dynamic_cast(f)) { + Val* f_ti = lowerSrcIndex(f_tv, scop->out()); + Val* prev_sum_mul = IrBuilder::create(f_tv->dtype()); + IrBuilder::create( + BinaryOpType::Mul, prev_sum_mul, f_ti, prev_sum); + prev_sum = prev_sum_mul; + } else { + prev_sum = mul(f_scalar, prev_sum); + } + } + + Expr* expr = + IrBuilder::create(scop->opType(), out, prev_sum, next_val); + + pushBack(expr); + GpuLower::current()->propagateExprInfo(scop, expr); +} + } // namespace nvfuser diff --git a/csrc/device_lower/pass/index.h b/csrc/device_lower/pass/index.h index 1712c29df98..b1372993c8e 100644 --- a/csrc/device_lower/pass/index.h +++ b/csrc/device_lower/pass/index.h @@ -68,6 +68,7 @@ class IndexLowering : private OptOutConstDispatch { void handle(const PadOp*) final; void handle(const SliceOp*) final; void handle(const CatOp*) final; + void handle(const ScanOp*) final; void handle(const kir::Asm*) final; void handle(const ForLoop*) final; diff --git a/csrc/device_lower/utils.cpp b/csrc/device_lower/utils.cpp index d9469efff93..fe120244800 100644 --- a/csrc/device_lower/utils.cpp +++ b/csrc/device_lower/utils.cpp @@ -114,6 +114,7 @@ bool isTvOp(const Expr* expr) { SdpaBwdOp, EmbeddingFwdOp, BroadcastOp, + ScanOp, SqueezeOp, ExpandOp, RepeatOp, diff --git a/csrc/device_lower/validation.cpp b/csrc/device_lower/validation.cpp index 9a21b4ac262..663456f7d24 100644 --- 
a/csrc/device_lower/validation.cpp +++ b/csrc/device_lower/validation.cpp @@ -1319,4 +1319,76 @@ void validate1dTmaLoad(Fusion* fusion) { } } +// When we compute a scan we compute something like the following +// +// Array T3; +// #pragma unroll +// for(nvfuser_index_t i1 = 0; i1 < 32; ++i1) { +// T3[0] +// = ((i1 > 0) ? (T3[0]) : 0.000000000e+00f) +// + (T2[i1]); +// } +// +// We need to validate that T3 is not inlined past the scan axis (the i1 loop +// in this case), because its lifetime must persist beyond the scan loop. Note +// that it is permissible to use `computeWith` as in this example to move the +// computed position inside the scan loop, alleviating the need to allocate an +// axis of size 32 in this case. +// +// +// Validate: +// 1. Outputs are inlined with all their uses past the scan dim +// 2. Discount factor and input are computed with this expression past the +// scan dim +// 3. Outputs are not inlined past the scan dimension, as this we require +// the scan outputs to be allocated outside the scan loop +void validateScans(Fusion* fusion) { + for (auto sop : ir_utils::getOpsOfType(fusion)) { + auto* out = sop->out()->as(); + TensorView* out_exclusive = sop->outExclusive(); + + // Find position of scan dim in loop domain + IterDomain* scan_id = out->getLogicalDomain().at((size_t)sop->scanDim()); + int64_t scan_pos = -1L; + for (int64_t pos : arange(out->nDims())) { + if (out->axis(pos) == scan_id) { + scan_pos = pos; + break; + } + } + NVF_ERROR( + scan_pos != -1L, + "Could not find scan dimension ", + scan_id->toString(), + " in loop domain. Scan dimensions must not be scheduled with splits or " + "merges"); + + const auto check_uses = [&](TensorView* output) { + for (Expr* use : output->uses()) { + for (Val* outp : use->outputs()) { + if (auto* out_tv = dynamic_cast(outp)) { + NVF_ERROR( + out_tv->getComputeWithPosition() >= scan_pos, + "Use of output, ", + use->toString(), + " must have all outputs inlined or computeWith to or past scan " + "position ", + scan_pos); + } + } + } + }; + + check_uses(out); + + if (out_exclusive != nullptr) { + NVF_ERROR(out_exclusive->getComputeWithPosition() >= scan_pos); + NVF_ERROR(out_exclusive->getComputeAtPosition() <= scan_pos); + check_uses(out_exclusive); + } + // Must have allocated outside scan loop + NVF_ERROR(out->getComputeAtPosition() <= scan_pos); + } +} + } // namespace nvfuser diff --git a/csrc/device_lower/validation.h b/csrc/device_lower/validation.h index 0b0a8545139..5cf8ded3f06 100644 --- a/csrc/device_lower/validation.h +++ b/csrc/device_lower/validation.h @@ -75,4 +75,8 @@ void validateReductions(Fusion* fusion); //! divisible. This is similar to vectorization, where we don't have an extra //! else branch to load the tailing elements. void validate1dTmaLoad(Fusion* fusion); + +//! 
Validate scheduling of ScanOp inputs and outputs +void validateScans(Fusion* fusion); + } // namespace nvfuser diff --git a/csrc/dispatch.h b/csrc/dispatch.h index 32f11cbb9a1..e7151cab1ec 100644 --- a/csrc/dispatch.h +++ b/csrc/dispatch.h @@ -102,12 +102,12 @@ class Val; f(ViewOp); \ f(CatOp); \ f(PadOp); \ + f(ScanOp); \ f(SliceOp); \ f(Split); \ f(ArgsortOp); \ f(GroupedMmaOp); \ f(TopKOp); \ - f(ScanOp); \ f(Merge); \ f(Swizzle); \ f(Swizzle2D); \ diff --git a/csrc/expr_evaluator.cpp b/csrc/expr_evaluator.cpp index 882c807b730..185316a09b5 100644 --- a/csrc/expr_evaluator.cpp +++ b/csrc/expr_evaluator.cpp @@ -284,6 +284,7 @@ const PolymorphicValue& ExpressionEvaluator::evaluate( if (!maybe_concrete_value.get().hasValue()) { if (auto def = value->definition()) { auto outputs = def->evaluate(*this, known_values); + NVF_ERROR(outputs.size() == def->outputs().size()); for (auto i : arange(def->outputs().size())) { known_values[def->output(i)] = std::move(outputs[i]); } diff --git a/csrc/id_model/indexing.cpp b/csrc/id_model/indexing.cpp index 78ddc2f49e1..11edbba3be3 100644 --- a/csrc/id_model/indexing.cpp +++ b/csrc/id_model/indexing.cpp @@ -971,7 +971,7 @@ std::pair, std::vector> TensorIndexer:: auto index_info = computeIndex( expr, indexed_ids, for_loops, isSharedMemoryTvForLdStMatrix(tv, expr)); for (const auto& [indexed_id, index] : override_index) { - index_info.index_map.emplace(traversalGraph().toGroup(indexed_id), index); + index_info.index_map[traversalGraph().toGroup(indexed_id)] = index; } const auto& index_map = index_info.index_map; auto replacement_map = getIndexReplacementMap( diff --git a/csrc/index_compute.cpp b/csrc/index_compute.cpp index 91a3bcd2fd1..685ca274523 100644 --- a/csrc/index_compute.cpp +++ b/csrc/index_compute.cpp @@ -2333,8 +2333,13 @@ kir::TensorIndex* Index::getConsumerIndex( GpuLower::current()->idModelOptions().consumerIndex() || GpuLower::current()->tmemInfo().hasTMemTensor()) { NVF_ERROR(rotated_loops.empty(), "Loop rotation is not supported"); + std::unordered_map override_index_ids; + for (auto& [pos, idx] : override_index) { + override_index_ids.emplace( + consumer->getMaybeAllocationDomain().at((size_t)pos), idx); + } index = GpuLower::current()->tensorIndexer().getLinearIndex( - consumer, consumer->definition(), loops); + consumer, consumer->definition(), loops, override_index_ids); if (generate_pointer) { auto address_offset = index; if (consumer->getMemoryType() == MemoryType::Shared) { diff --git a/csrc/ir/internal_base_nodes.h b/csrc/ir/internal_base_nodes.h index b96b226443d..657234aa555 100644 --- a/csrc/ir/internal_base_nodes.h +++ b/csrc/ir/internal_base_nodes.h @@ -189,6 +189,10 @@ class IterDomain : public Val { return getIterType() == IterType::Symbolic; } + bool isScan() const { + return getIterType() == IterType::Scan; + } + bool isGatherScatter() const { return getIterType() == IterType::GatherScatter; } @@ -724,6 +728,7 @@ class TensorDomain : public Val { const std::unordered_map& old2new); static std::vector noReductions(const std::vector&); + static std::vector noScans(const std::vector&); static std::vector noBroadcasts(const std::vector&); static std::vector noDevices(const std::vector&); diff --git a/csrc/ir/internal_nodes.h b/csrc/ir/internal_nodes.h index 1c66a0f8e77..7a34f8e11e3 100644 --- a/csrc/ir/internal_nodes.h +++ b/csrc/ir/internal_nodes.h @@ -3167,12 +3167,17 @@ class ScanOp : public Expr { public: using Expr::Expr; + // NOTE: We translate these nodes to other nodes during indexing, so we should + // never expect 
to receive TensorIndex arguments here ScanOp( IrBuilderPasskey, BinaryOpType op_type, + TensorView* output_inclusive, + TensorView* output_exclusive, + TensorView* output_reduction, + TensorView* input, + Val* discount_factor, Val* init, - Val* out, - Val* in, int64_t dim); NVFUSER_DECLARE_CLONE_AND_CREATE @@ -3185,24 +3190,52 @@ class ScanOp : public Expr { std::string toInlineString(int indent_size = 0) const override; //! Returns the inclusive scan output - Val* out() const { - return output(0); + TensorView* out() const { + return output(0)->as(); } - Val* in() const { - return input(0); + //! Returns the exclusive scan output if available, otherwise nullptr + TensorView* outExclusive() const { + if (!hasExclusive()) { + return nullptr; + } + return output(1)->as(); } - Val* init() const { - return attributeVal(0); + //! Returns the exclusive scan output if available, otherwise nullptr + TensorView* outReduction() const { + if (!hasReduction()) { + return nullptr; + } + return output(hasExclusive() ? 2 : 1)->as(); + } + + TensorView* in() const { + return input(0)->as(); + } + + Val* discountFactor() const { + return inputs().size() > 1 ? input(1) : nullptr; } BinaryOpType opType() const { - return attribute(1); + return attribute(0); } int64_t scanDim() const { - return attribute(2); + return attribute(1); + } + + bool hasExclusive() const { + return attribute(2); + } + + bool hasReduction() const { + return attribute(3); + } + + Val* init() const { + return attributeVal(4); } std::vector evaluate( diff --git a/csrc/ir/nodes.cpp b/csrc/ir/nodes.cpp index f2aef8fd6e2..b2e4fd5654a 100644 --- a/csrc/ir/nodes.cpp +++ b/csrc/ir/nodes.cpp @@ -29,6 +29,7 @@ #include #include +#include #include #include #include @@ -1612,6 +1613,24 @@ std::vector ReductionOp::evaluate( case BinaryOpType::Min: return {at::amin(input, reduction_axes)}; break; + case BinaryOpType::LHS: { + std::vector slice_loc( + input.dim(), torch::indexing::Ellipsis); + for (int64_t ax : reduction_axes) { + slice_loc.at(ax) = 0L; + } + return {input.index(slice_loc)}; + break; + } + case BinaryOpType::RHS: { + std::vector slice_loc( + input.dim(), torch::indexing::Ellipsis); + for (int64_t ax : reduction_axes) { + slice_loc.at(ax) = -1L; + } + return {input.index(slice_loc)}; + break; + } default: NVF_CHECK( false, @@ -3836,6 +3855,17 @@ std::vector TensorDomain::noReductions( return noReductionDomain; } +std::vector TensorDomain::noScans( + const std::vector& td) { + std::vector noReductionDomain; + std::copy_if( + td.begin(), + td.end(), + std::back_inserter(noReductionDomain), + [](IterDomain* id) { return !id->isReduction() && !id->isStride(); }); + return noReductionDomain; +} + std::vector TensorDomain::noBroadcasts( const std::vector& td) { std::vector noBroadcastDomain; @@ -6062,25 +6092,51 @@ NVFUSER_DEFINE_CLONE_AND_CREATE(GroupedMmaOp) ScanOp::ScanOp( IrBuilderPasskey passkey, BinaryOpType op_type, + TensorView* output_inclusive, + TensorView* output_exclusive, + TensorView* output_reduction, + TensorView* input, + Val* discount_factor, Val* init, - Val* out, - Val* in, int64_t dim) : Expr(passkey) { - addOutput(out); - addInput(in); - addAttribute(init); + addOutput(output_inclusive); + if (output_exclusive != nullptr) { + addOutput(output_exclusive); + } + if (output_reduction != nullptr) { + addOutput(output_reduction); + } + addInput(input); + if (discount_factor != nullptr) { + addInput(discount_factor); + } addDataAttribute(op_type); addDataAttribute(dim); + addDataAttribute(output_exclusive != nullptr); + 
addDataAttribute(output_reduction != nullptr); + addAttribute(init); } std::string ScanOp::toString(int indent_size) const { std::stringstream ss; indent(ss, indent_size) << out()->toString(); + if (hasExclusive()) { + ss << ",\n"; + indent(ss, indent_size) << outExclusive()->toString(); + } + if (hasReduction()) { + ss << ",\n"; + indent(ss, indent_size) << outReduction()->toString(); + } ss << "\n"; - indent(ss, indent_size + 1) << " = scan(" << in()->toString() << ",\n"; + indent(ss, indent_size + 1) << " = scan(" << opType() << ",\n"; + indent(ss, indent_size + 1) << " " << in()->toString() << ",\n"; indent(ss, indent_size + 1) << " dim=" << scanDim() << ",\n"; - indent(ss, indent_size + 1) << " op_type=" << opType() << ",\n"; + if (discountFactor() != nullptr) { + indent(ss, indent_size + 1) + << " discount_factor=" << discountFactor()->toString() << ",\n"; + } indent(ss, indent_size + 1) << " init=" << init()->toInlineString() << ")\n"; return ss.str(); @@ -6095,27 +6151,121 @@ std::vector ScanOp::evaluate( const std::vector& inputs) const { auto input = inputs.at(0).as(); - NVF_ERROR(inputs.size() == 1); + if (discountFactor() == nullptr) { + NVF_ERROR(inputs.size() == 1); - at::Tensor out; - switch (opType()) { - case BinaryOpType::Add: - out = at::cumsum(input, scanDim()); - break; - case BinaryOpType::Max: - out = std::get<0>(at::cummax(input, scanDim())); - break; - case BinaryOpType::Min: - out = std::get<0>(at::cummin(input, scanDim())); - break; - case BinaryOpType::Mul: - out = at::cumprod(input, scanDim()); - break; - default: - NVF_THROW("Unhandled opType() ", opType()); - } + at::Tensor out_inclusive; + float identity; + switch (opType()) { + case BinaryOpType::Add: + out_inclusive = at::cumsum(input, scanDim()); + identity = 0.0; + break; + case BinaryOpType::Max: + out_inclusive = std::get<0>(at::cummax(input, scanDim())); + identity = -std::numeric_limits::infinity(); + break; + case BinaryOpType::Min: + out_inclusive = std::get<0>(at::cummin(input, scanDim())); + identity = std::numeric_limits::infinity(); + break; + case BinaryOpType::Mul: + out_inclusive = at::cumprod(input, scanDim()); + identity = 1.0; + break; + default: + NVF_THROW("Unhandled opType() ", opType()); + } + if (outExclusive() == nullptr) { + return {out_inclusive}; + } else { + std::vector pad_widths( + 2 * ((int64_t)input.dim() - scanDim()), 0L); + pad_widths[pad_widths.size() - 2] = 1L; + pad_widths[pad_widths.size() - 1] = -1L; + at::Tensor out_exclusive = + at::pad(out_inclusive, pad_widths, "constant", identity); + return {out_inclusive, out_exclusive}; + } + } else { + // auto discount_factor = inputs.at(1).as(); + NVF_ERROR(inputs.size() == 2); + + const PolymorphicValue& df = inputs.at(1); + + NVF_ERROR(df.hasValue()); + + at::Tensor out = at::zeros_like(input); + at::Tensor out_exclusive; + + if (outExclusive() != nullptr) { + out_exclusive = at::ones_like(out); + PolymorphicValue init_pv = init()->value(); + if (init_pv.is()) { + out_exclusive *= init_pv.as(); + } else if (init_pv.is()) { + out_exclusive *= init_pv.as(); + } else { + NVF_THROW("Unsupported type for evaluation: ", init()->dtype()); + } + } + + at::Tensor cur; + std::vector slice_pos; + slice_pos.reserve((size_t)input.dim()); + for ([[maybe_unused]] int64_t i : arange(input.dim())) { + slice_pos.push_back(at::indexing::Slice()); + } + for (int64_t i : arange(input.size(scanDim()))) { + slice_pos.at((size_t)scanDim()) = at::indexing::TensorIndex((int64_t)i); - return {out}; + at::Tensor next_slice = input.index(slice_pos); + 
+ if (i == 0) { + cur = next_slice; + } else { + if (outExclusive() != nullptr) { + out_exclusive.index(slice_pos).copy_(cur); + } + + if (df.is()) { + cur *= df.as(); + } else if (df.is()) { + cur *= df.as(); + } else if (df.is()) { + // TODO: handle case where scanDim() is broadcast in discount factor + cur *= df.as().index(slice_pos); + } else { + NVF_THROW("Unhandled discount factor type"); + } + + switch (opType()) { + case BinaryOpType::Add: + cur += next_slice; + break; + case BinaryOpType::Max: + cur = cur.maximum(next_slice); + break; + case BinaryOpType::Min: + cur = cur.minimum(next_slice); + break; + case BinaryOpType::Mul: + cur *= next_slice; + break; + default: + NVF_THROW("Unhandled opType() ", opType()); + } + } + + out.index(slice_pos).copy_(cur); + } + + if (outExclusive() == nullptr) { + return {out}; + } else { + return {out, out_exclusive}; + } + } } NVFUSER_DEFINE_CLONE_AND_CREATE(ScanOp) diff --git a/csrc/logical_domain_map.cpp b/csrc/logical_domain_map.cpp index 19e87ec45db..c765e30a8c5 100644 --- a/csrc/logical_domain_map.cpp +++ b/csrc/logical_domain_map.cpp @@ -1285,7 +1285,7 @@ void ComputeAtLogicalDomainMapBuilder::mapPointwiseLikeOp(Expr* expr) { if (expr->outputs().size() > 1) { NVF_ERROR( expr->isA() || expr->isA() || - expr->isA(), + expr->isA() || expr->isA(), "Unknown multi-output Expr type ", expr->getOpString(), " is found"); diff --git a/csrc/logical_domain_map.h b/csrc/logical_domain_map.h index 8a0fbc5e2fd..928f63f5acd 100644 --- a/csrc/logical_domain_map.h +++ b/csrc/logical_domain_map.h @@ -513,6 +513,10 @@ class ComputeAtLogicalDomainMapBuilder : private BackwardVisitor { mapPointwiseLikeOp(op); } + void handle(ScanOp* op) override { + mapPointwiseLikeOp(op); + } + void handle(SliceOp* op) override { mapPointwiseLikeOp(op); } diff --git a/csrc/ops/arith.cpp b/csrc/ops/arith.cpp index 05efcbde208..64955e4bd1c 100644 --- a/csrc/ops/arith.cpp +++ b/csrc/ops/arith.cpp @@ -2554,50 +2554,94 @@ TopKResult topk( out_values->as(), out_indices->as()); } -TensorView* scan( - TensorView* in_tv, +ScanResult scan( + TensorView* tv, int64_t dim, BinaryOpType op_type, - Val* init) { + Val* init, + Val* discount_factor, + bool return_exclusive, + bool return_reduction) { const std::vector logical_dom = - TensorDomain::noReductions(in_tv->getLogicalDomain()); + TensorDomain::noReductions(tv->getLogicalDomain()); dim = wrapDim(dim, (int64_t)logical_dom.size()); IterDomain* scan_id = logical_dom.at((size_t)dim); + if (init == nullptr) { + init = ops::binOpIdentity(op_type, tv->dtype()); + NVF_ERROR(init != nullptr); + } + // Special case: scanning along broadcast dimension is no-op // Assumes init is identity for op_type if (scan_id->isBroadcast()) { - NVF_ERROR( - !scan_id->hasExpandedExtent(), - "Closed-form scan of expanded dimension is not yet implemented"); - return set(in_tv); + if (scan_id->hasExpandedExtent()) { + NVF_THROW( + "Closed-form scan of expanded dimension is not yet implemented"); + } + // Exclusive scan is just the init val + return {set(tv), mul(init, ones_like(tv))}; } - DataType dtype = in_tv->dtype(); - auto new_dom = ops::newOutputDomain({in_tv}); + std::vector new_dom; + DataType dtype = tv->dtype(); + if (discount_factor == nullptr) { + new_dom = ops::newOutputDomain({tv}); + } else { + new_dom = ops::newOutputDomain({tv, discount_factor}); + dtype = promoteType(tv->dtype(), discount_factor->dtype()); + tv = maybeCastOp(dtype, tv); + discount_factor = maybeCastOp(dtype, discount_factor); + } + new_dom.at((size_t)dim) = 
IterDomainBuilder(new_dom.at((size_t)dim)) + .iter_type(IterType::Scan) + .build(); auto* td = IrBuilder::create( new_dom, TensorDomain::getContiguityFilledWith(new_dom, true)); - auto out_tv = IrBuilder::create(td, in_tv->dtype()); - if (init == nullptr) { - init = ops::binOpIdentity(op_type, dtype); - NVF_ERROR(init != nullptr); + ScanResult result; + + result.inclusive = IrBuilder::create(td, tv->dtype()); + + if (return_exclusive) { + result.exclusive = ops::newOutputTV({result.inclusive}, tv->dtype()); + } + + if (return_reduction) { + std::vector red_dom = ops::newOutputDomain({tv}); + red_dom.at((size_t)dim) = IterDomainBuilder(red_dom.at((size_t)dim)) + .iter_type(IterType::Reduction) + .build(); + auto* red_td = IrBuilder::create( + red_dom, TensorDomain::getContiguityFilledWith(red_dom, true)); + + result.reduction = IrBuilder::create(red_td, tv->dtype()); } IrBuilder::createInContainer( - in_tv->container(), op_type, init, out_tv, in_tv, dim); + tv->container(), + op_type, + result.inclusive, + result.exclusive, + result.reduction, + tv, + discount_factor, + init, + dim); - return out_tv; + return result; } -TensorView* prefixSum(TensorView* tv, int64_t dim) { +TensorView* prefixSum(TensorView* tv, int64_t dim, Val* discount_factor) { return scan( - tv, - dim, - BinaryOpType::Add, - /*init=*/tv->fusion()->zeroVal(tv->dtype())); + tv, + dim, + BinaryOpType::Add, + /*init=*/tv->fusion()->zeroVal(tv->dtype()), + discount_factor) + .inclusive; } } // namespace nvfuser diff --git a/csrc/ops/arith.h b/csrc/ops/arith.h index 6db3e66d8c1..d912aa1ae8a 100644 --- a/csrc/ops/arith.h +++ b/csrc/ops/arith.h @@ -804,25 +804,51 @@ NVF_API TopKResult topk( bool sorted = false, bool maybe_symbolic = true); +struct ScanResult { + TensorView* inclusive = nullptr; + TensorView* exclusive = nullptr; + TensorView* reduction = nullptr; +}; + //! Computes an inclusive scan of a tensor in a single dimension. //! -//! Given a 1D input tensor x, this computes the output +//! Given a 1D input tensor x and discount factor f, this computes the output //! recursively via //! -//! y = scan(x, 0, Add, zeroVal()) +//! y = scan(x, 0, Add, zeroVal(), f) +//! +//! y[0] = x[0] +//! y[i] = f*y[i-1] + x[i] for 0 < i < n +//! +//! Note that the discount factor can also be a TensorView, in which case we +//! compute +//! +//! y[0] = x[0] +//! y[i] = f[i]*y[i-1] + x[i] for 0 < i < n +//! +//! Notice that the first discount factor in the scanned dimension is ignored. +//! +//! If `discount_factor` is null, then we compute the regular prefix sum: //! //! y[0] = x[0] //! y[i] = y[i-1] + x[i] for 0 < i < n //! //! If the dimension being scanned is an expanded broadcast, we throw an error. -NVF_API TensorView* scan( - TensorView* in_tv, +NVF_API ScanResult scan( + TensorView* tv, int64_t dim, BinaryOpType op_type, - Val* init = nullptr); - -//! This is an alias for scan(tv, dim, BinaryOpType::Add, zeroVal()) -NVF_API TensorView* prefixSum(TensorView* tv, int64_t dim); + Val* init = nullptr, + Val* discount_factor = nullptr, + bool return_exclusive = false, + bool return_reduction = false); + +//! This is an alias for scan(tv, dim, BinaryOpType::Add, zeroVal(), +//! discount_factor) +NVF_API TensorView* prefixSum( + TensorView* tv, + int64_t dim, + Val* discount_factor = nullptr); //! 
Another alias for PyTorch's cumsum NVF_API inline TensorView* cumsum(TensorView* tv, int64_t dim) { diff --git a/csrc/ops/utils.cpp b/csrc/ops/utils.cpp index 21af88fb710..56c51e65c46 100644 --- a/csrc/ops/utils.cpp +++ b/csrc/ops/utils.cpp @@ -156,11 +156,13 @@ IterType promoteIterType(IterType type1, IterType type2) { "Invalid IterType: ", type2); - // Do not propagate GatherScatter and VectorComponent - if (type1 == IterType::VectorComponent || type1 == IterType::GatherScatter) { + // Do not propagate GatherScatter, VectorComponent, or Scan + if (type1 == IterType::VectorComponent || type1 == IterType::GatherScatter || + type1 == IterType::Scan) { type1 = IterType::Iteration; } - if (type2 == IterType::VectorComponent || type2 == IterType::GatherScatter) { + if (type2 == IterType::VectorComponent || type2 == IterType::GatherScatter || + type2 == IterType::Scan) { type2 = IterType::Iteration; } @@ -426,6 +428,7 @@ IterDomain* newOutputIterDomain( .iter_type(IterType::Broadcast) .build(); } + NVF_ERROR(out_domain != nullptr); return out_domain; } #if defined(__GNUC__) && !defined(__clang__) @@ -451,7 +454,7 @@ std::vector newOutputDomain(const std::vector& vals) { input_ids.reserve(tvs.size()); for (auto* tv : tvs) { auto dom = TensorDomain::noReductions(tv->getLogicalDomain()); - input_ids.emplace_back(dom[dim_i]); + input_ids.emplace_back(dom.at(dim_i)); } out_domain[dim_i] = newOutputIterDomain(input_ids); } diff --git a/csrc/type.cpp b/csrc/type.cpp index fe34c781231..adb727b217f 100644 --- a/csrc/type.cpp +++ b/csrc/type.cpp @@ -613,6 +613,10 @@ static const char* binary_op_type2string(BinaryOpType t) { return "lessThan"; case BinaryOpType::NE: return "notEqual"; + case BinaryOpType::LHS: + return "lhs"; + case BinaryOpType::RHS: + return "rhs"; default: NVF_THROW("No string found for binary op type."); } @@ -852,6 +856,8 @@ static const char* iter_type2string(IterType t) { return "n"; case IterType::VectorComponent: return "v"; + case IterType::Scan: + return "c"; case IterType::Symbolic: return "?"; default: diff --git a/csrc/type.h b/csrc/type.h index 685a595708a..9b9857d2724 100644 --- a/csrc/type.h +++ b/csrc/type.h @@ -634,7 +634,11 @@ enum class BinaryOpType { LogicalOr, // generate complex from real and imaginary parts - Complex + Complex, + + // These just return one or the other of the arguments + LHS, + RHS }; enum class ScatterOpType { Set }; @@ -714,7 +718,8 @@ enum class IterType { Stride, GatherScatter, VectorComponent, - Symbolic + Symbolic, + Scan }; // Used for Iteration Domain mapping modes in ComputeAtMap diff --git a/runtime/helpers.cu b/runtime/helpers.cu index 91805e2a1aa..bd63b1b37ec 100644 --- a/runtime/helpers.cu +++ b/runtime/helpers.cu @@ -384,6 +384,16 @@ __device__ T gcd(T a, T b) { return a; } +template +__device__ T lhs(T a, U b) { + return a; +} + +template +__device__ U rhs(T a, U b) { + return b; +} + template bool isfinite(T x) { return ::isfinite(x); diff --git a/tests/cpp/test_compute_with.cpp b/tests/cpp/test_compute_with.cpp index 2d2e2ef150b..c27fa08c7a7 100644 --- a/tests/cpp/test_compute_with.cpp +++ b/tests/cpp/test_compute_with.cpp @@ -200,6 +200,7 @@ TEST_F(ComputeWithTest, ComputeWith2) { auto output_tv4 = div(exp_tv1_copy, bcast_sum_tv3); fusion.addOutput(output_tv4); + fusion.printMath(); auto input_tv0_cache = input_tv0->cacheAfter(); @@ -218,6 +219,8 @@ TEST_F(ComputeWithTest, ComputeWith2) { input_tv0_cache->axis(1)->parallelize(ParallelType::TIDx); scheduler_utils::parallelizeAllLike(input_tv0_cache); + fusion.printMath(); + 
GpuLower gpulw(&fusion); // Lowering should automatcially pick the first consumer of the // computed-with tensor as its target diff --git a/tests/cpp/test_scan.cpp b/tests/cpp/test_scan.cpp index 5ca6393cc2b..c1b27f78b98 100644 --- a/tests/cpp/test_scan.cpp +++ b/tests/cpp/test_scan.cpp @@ -7,13 +7,15 @@ // clang-format on #include #include -#include -#include #include #include +#include #include #include +#include +#include + namespace nvfuser { using ScanTest = NVFuserTest; @@ -26,7 +28,7 @@ TEST_F(ScanTest, BasicScanAdd) { auto tv0 = makeConcreteTensor({4, 8}); fusion.addInput(tv0); - auto tv_result = scan(tv0, /*dim=*/1, BinaryOpType::Add); + auto tv_result = scan(tv0, /*dim=*/1, BinaryOpType::Add).inclusive; fusion.addOutput(tv_result); fusion.printMath(); @@ -48,7 +50,7 @@ TEST_F(ScanTest, BasicScanMax) { auto tv0 = makeConcreteTensor({4, 8}); fusion.addInput(tv0); - auto tv_result = scan(tv0, /*dim=*/1, BinaryOpType::Max); + auto tv_result = scan(tv0, /*dim=*/1, BinaryOpType::Max).inclusive; fusion.addOutput(tv_result); auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); @@ -68,7 +70,7 @@ TEST_F(ScanTest, BasicScanMin) { auto tv0 = makeConcreteTensor({4, 8}); fusion.addInput(tv0); - auto tv_result = scan(tv0, /*dim=*/1, BinaryOpType::Min); + auto tv_result = scan(tv0, /*dim=*/1, BinaryOpType::Min).inclusive; fusion.addOutput(tv_result); auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); @@ -88,7 +90,7 @@ TEST_F(ScanTest, BasicScanMul) { auto tv0 = makeConcreteTensor({4, 8}); fusion.addInput(tv0); - auto tv_result = scan(tv0, /*dim=*/1, BinaryOpType::Mul); + auto tv_result = scan(tv0, /*dim=*/1, BinaryOpType::Mul).inclusive; fusion.addOutput(tv_result); auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); @@ -108,7 +110,7 @@ TEST_F(ScanTest, ScanDifferentDimensions) { auto tv0 = makeConcreteTensor({2, 4, 6}); fusion.addInput(tv0); - auto tv_result = scan(tv0, /*dim=*/0, BinaryOpType::Add); + auto tv_result = scan(tv0, /*dim=*/0, BinaryOpType::Add).inclusive; fusion.addOutput(tv_result); auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); @@ -128,7 +130,7 @@ TEST_F(ScanTest, Scan1D) { auto tv0 = makeConcreteTensor({10}); fusion.addInput(tv0); - auto tv_result = scan(tv0, /*dim=*/0, BinaryOpType::Add); + auto tv_result = scan(tv0, /*dim=*/0, BinaryOpType::Add).inclusive; fusion.addOutput(tv_result); auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); @@ -157,7 +159,7 @@ TEST_F(ScanTest, ScanWithSimpleArithmetic) { auto tv1 = add(tv0, IrBuilder::create(1.0)); // Scan operation - auto tv2 = scan(tv1, /*dim=*/1, BinaryOpType::Add); + auto tv2 = scan(tv1, /*dim=*/1, BinaryOpType::Add).inclusive; fusion.addOutput(tv2); @@ -186,7 +188,7 @@ TEST_F(ScanTest, ScanWithArithmeticOps) { auto tv3 = sub(tv2, IrBuilder::create(0.5)); // Scan operation - auto tv4 = scan(tv3, /*dim=*/1, BinaryOpType::Add); + auto tv4 = scan(tv3, /*dim=*/1, BinaryOpType::Add).inclusive; // Additional operation after scan auto tv5 = div(tv4, IrBuilder::create(3.0)); @@ -202,4 +204,522 @@ TEST_F(ScanTest, ScanWithArithmeticOps) { testValidate(executor_cache.fusion(), outputs, {input}, __LINE__, __FILE__); } +// Simple test case for defining a scan +TEST_F(ScanTest, Concrete1D) { + EnableOptionsGuard::getCurOptions().set(EnableOption::IdModel, {"all"}); + + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + auto tv0 = makeConcreteTensor({32}); + fusion->addInput(tv0); + + auto tv1 = 
prefixSum(tv0, /*dim=*/-1, /*discount_factor=*/nullptr); + + fusion->addOutput(tv1); + + tv0->cacheAfter(); + tv1->cacheBefore(); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({32}, options); + + KernelExecutor ke; + ke.compile(fusion.get(), {t0}); + + auto cg_outputs = ke.run({t0}); + + testValidate(fusion.get(), cg_outputs, {t0}, __LINE__, __FILE__); +} + +TEST_F(ScanTest, DiscountFactorScalar1D) { + EnableOptionsGuard::getCurOptions().set(EnableOption::IdModel, {"all"}); + + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + auto tv0 = makeConcreteTensor({32}); + fusion->addInput(tv0); + + auto tv1 = prefixSum( + tv0, + /*dim=*/-1, + /*discount_factor=*/IrBuilder::create(0.8, DataType::Float)); + + fusion->addOutput(tv1); + + tv0->cacheAfter(); + tv1->cacheBefore(); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({32}, options); + + KernelExecutor ke; + ke.compile(fusion.get(), {t0}); + + auto cg_outputs = ke.run({t0}); + + testValidate(fusion.get(), cg_outputs, {t0}, __LINE__, __FILE__); +} + +TEST_F(ScanTest, DiscountFactorTensor1D) { + EnableOptionsGuard::getCurOptions().set(EnableOption::IdModel, {"all"}); + + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + auto tv0 = makeConcreteTensor({32}); + fusion->addInput(tv0); + auto tv1 = makeConcreteTensor({32}); + fusion->addInput(tv1); + + auto tv2 = prefixSum( + tv0, + /*dim=*/-1, + /*discount_factor=*/tv1); + + fusion->addOutput(tv2); + + tv0->cacheAfter(); + tv2->cacheBefore(); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({32}, options); + at::Tensor t1 = at::randn({32}, options).exp(); + + KernelExecutor ke; + ke.compile(fusion.get(), {t0, t1}); + + auto cg_outputs = ke.run({t0, t1}); + + testValidate(fusion.get(), cg_outputs, {t0, t1}, __LINE__, __FILE__); +} + +// Simple test case for defining a scan +TEST_F(ScanTest, Concrete2D) { + EnableOptionsGuard::getCurOptions().set(EnableOption::IdModel, {"all"}); + + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + auto tv0 = makeConcreteTensor({4, 32}); + fusion->addInput(tv0); + + auto tv1 = prefixSum(tv0, /*dim=*/0, /*discount_factor=*/nullptr); + + fusion->addOutput(tv1); + + tv0->cacheAfter(); + tv1->cacheBefore(); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({4, 32}, options); + + KernelExecutor ke; + ke.compile(fusion.get(), {t0}); + + auto cg_outputs = ke.run({t0}); + + testValidate(fusion.get(), cg_outputs, {t0}, __LINE__, __FILE__); +} + +// This is similar to what's needed for serial online softmax +TEST_F(ScanTest, OnlineSoftmax) { + EnableOptionsGuard::getCurOptions().set(EnableOption::IdModel, {"all"}); + + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + auto x = makeSymbolicTensor(1); + fusion->addInput(x); + + int64_t scan_dim = 0; + + // Online normalizer for softmax: https://arxiv.org/abs/1805.02867 + // + // Given x[i] for i=0 .. N-1: + // + // m[-1] = -infinity + // d[-1] = 0 + // for j = 0 .. 
N-1 + // m[j] = max(m[j-1], x[j]) + // d[j] = d[j-1] * exp(m[j-1] - m[j]) + exp(x[j] - m[j]) + // + // Final denominator is d[N-1] + + auto* neg_infty = IrBuilder::create( + -std::numeric_limits::infinity(), DataType::Double); + ScanResult max_scan_result = scan( + set(x), + scan_dim, + BinaryOpType::Max, + /*init=*/neg_infty, + /*discount_factor=*/nullptr, + /*return_exclusive=*/true); // max x[j] over j = 0 .. i + TensorView* m = max_scan_result.inclusive; + TensorView* m_prev = max_scan_result.exclusive; + // normalize by running max and exponentiate + TensorView* exp_x_m = exp(sub(x, m)); + // Discount factor is exponentiated delta: exp(m[i-1] - m[i]) + TensorView* discount = exp(sub(m_prev, m)); + + auto denoms = prefixSum(exp_x_m, scan_dim, discount); + + auto norm_factor = reductionOp( + BinaryOpType::RHS, + {scan_dim}, + /*init=*/fusion->zeroVal(DataType::Float), + denoms); + + auto full_max = reductionOp( + BinaryOpType::RHS, + {scan_dim}, + /*init=*/neg_infty, + m); + + auto max_bcast = broadcast(full_max, {true}); + auto norm_factor_bcast = broadcast(norm_factor, {true}); + // Recompute numerator + auto numer = exp(sub(set(x), max_bcast)); + + auto result = div(numer, norm_factor_bcast); + + fusion->addOutput(result); + + // Don't cache inputs for this fusion because we will need to recompute + // exp((x-m)) using the final max in a separate loop so caching would mean + // we'd need to hold the whole tensor in registers. Instead, we manually call + // set(x) twice in the definition above to give us two separate caches. + scheduler_utils::cacheAndForkOutputs(fusion.get(), /*unroll=*/true); + + // We don't inline the scans past the scan dimension + std::unordered_set uninlineable_ids; + for (TensorView* tv : {m, m_prev, denoms}) { + for (IterDomain* id : tv->getLoopDomain()) { + uninlineable_ids.insert(id); + } + } + + inlineMost(uninlineable_ids); + + // These TVs are not inlined, but instead we set computeWith on them + for (TensorView* tv : {m, m_prev, denoms}) { + tv->computeWith(-1); + for (Val* v : tv->definition()->inputs()) { + // By using `uninlineable_ids` above, we prevent producers of scan from + // inlining with the ScanOp past the scan dim, even though this is + // desired. Here we do this inlining manually instead. + v->as()->inlineAt(-1); + } + } + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({32}, options); + + KernelExecutor ke; + ke.compile(fusion.get(), {t0}); + + auto cg_outputs = ke.run({t0}); + + auto ref = at::softmax(t0, 0); + EXPECT_TRUE(at::allclose(cg_outputs[0].as(), ref)) + << " returned " << cg_outputs[0].as().item() + << " but expected " << ref.item(); + + // Test automatic evaluation also + testValidate(fusion.get(), cg_outputs, {t0}, __LINE__, __FILE__); +} + +// +TEST_F(ScanTest, OnlineSoftmaxOuter) { + EnableOptionsGuard::getCurOptions().set(EnableOption::IdModel, {"all"}); + + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + // TODO: Allow outer dim to be symbolic + auto x = makeConcreteTensor({-1, 32}); + fusion->addInput(x); + + int64_t scan_dim = 0; + + // Online normalizer for softmax: https://arxiv.org/abs/1805.02867 + // + // Given x[i] for i=0 .. N-1: + // + // m[-1] = -infinity + // d[-1] = 0 + // for j = 0 .. 
N-1 + // m[j] = max(m[j-1], x[j]) + // d[j] = d[j-1] * exp(m[j-1] - m[j]) + exp(x[j] - m[j]) + // + // Final denominator is d[N-1] + + auto* neg_infty = IrBuilder::create( + -std::numeric_limits::infinity(), DataType::Double); + ScanResult max_scan_result = scan( + set(x), + scan_dim, + BinaryOpType::Max, + /*init=*/neg_infty, + /*discount_factor=*/nullptr, + /*return_exclusive=*/true); // max x[j] over j = 0 .. i + TensorView* m = max_scan_result.inclusive; + TensorView* m_prev = max_scan_result.exclusive; + // normalize by running max and exponentiate + TensorView* exp_x_m = exp(sub(x, m)); + // Discount factor is exponentiated delta: exp(m[i-1] - m[i]) + TensorView* discount = exp(sub(m_prev, m)); + + auto denoms = prefixSum(exp_x_m, scan_dim, discount); + + auto norm_factor = reductionOp( + BinaryOpType::RHS, + {scan_dim}, + /*init=*/fusion->zeroVal(DataType::Float), + denoms); + + fusion->addOutput(norm_factor); + + scheduler_utils::cacheInputs(fusion.get(), /*unroll=*/true); + scheduler_utils::cacheAndForkOutputs(fusion.get(), /*unroll=*/true); + + // We don't inline the scans past the scan dimension + std::unordered_set uninlineable_ids; + for (TensorView* tv : {m, m_prev, denoms}) { + for (IterDomain* id : tv->getLoopDomain()) { + uninlineable_ids.insert(id); + } + } + + inlineMost(uninlineable_ids); + + // These TVs are not inlined, but instead we set computeWith on them + for (TensorView* tv : {m, m_prev, denoms}) { + tv->computeWith(-1); + for (Val* v : tv->definition()->inputs()) { + // By using `uninlineable_ids` above, we prevent producers of scan from + // inlining with the ScanOp past the scan dim, even though this is + // desired. Here we do this inlining manually instead. + v->as()->inlineAt(-1); + } + } + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({16, 32}, options); + + KernelExecutor ke; + ke.compile(fusion.get(), {t0}); + + auto cg_outputs = ke.run({t0}); + + at::Tensor ref = (t0 - std::get<0>(t0.max(/*dim=*/0, /*keepdim=*/true))) + .exp() + .sum(/*dim=*/0); + EXPECT_TRUE(at::allclose(cg_outputs[0].as(), ref)) + << " returned " << cg_outputs[0].as() << " but expected " + << ref; + + // Test automatic evaluation also + testValidate(fusion.get(), cg_outputs, {t0}, __LINE__, __FILE__); +} + +// This is a simplified version of FlashAttention that does not circular buffer +// the inputs or use mma instructions, but has the same general computation +// pattern. +// +// Dao et al. 2022. FlashAttention: Fast and Memory-Efficient Exact Attention +// with IO-Awareness. https://arxiv.org/abs/2205.14135 +TEST_F(ScanTest, FlashAttentionNoMma) { + EnableOptionsGuard::getCurOptions().set(EnableOption::IdModel, {"all"}); + + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + // Inputs are Q, K, V + // Normally each of these would be 2D, shaped N-by-d + // Output is softmax(Q@K.T, dim=1)@V + // + // Here, I am hard-coding a tiling split and transpose: + // + // Q: N1o, d1o, N1i, d1i + // K: N2o, d1o, d1i, N2i + // V: N2o, d2o, d2i, N2i + // + // For the Q@K.T matmul, we have the following dim roles: + // + // M: N1o, N1i + // N: N2o, N2i + // K: d1o, d1i + // + // For the second matmul S@V the roles are: + // + // M: N1o, N1i + // N: d2o, d2i + // K: N2o, N2i + // + // The overall output should be of size N1o, N1i, d2o, d2i since it is the + // result of that final matmul + // + // The general strategy for ordinary flash attention is to have two main + // serial loops: over the outer "K" dims d1o and N2o. 
The inner dims are + // computed via a tile matmul, equivalent to two more loops d1i and N2i. + // + // We grid parallelize across CTAs in the N1o and d2o dimensions and each CTA + // computes a final output of size N1i by d2i. + int64_t N1o = 2; + int64_t N1i = 3; + int64_t N2o = 4; + int64_t N2i = 5; + int64_t d1o = 6; + int64_t d1i = 7; + int64_t d2o = 8; + int64_t d2i = 9; + + // [N1o, N2o, d1o, d2o, N1i, N2i, d1i, d2i] + auto Q = makeConcreteTensor({N1o, 1, d1o, 1, N1i, 1, d1i, 1}); + auto K = makeConcreteTensor({1, N2o, d1o, 1, 1, N2i, d1i, 1}); + auto V = makeConcreteTensor({1, N2o, 1, d2o, 1, N2i, 1, d2i}); + fusion->addInput(Q); + fusion->addInput(K); + fusion->addInput(V); + + // Notation is from Algorithm 1 of Dao et al. 2022 + + // TODO: mma + auto S = + sum(mul(Q, K), + /*dims=*/{-2}, + /*keep_dim=*/true); // [N1o, N2o, d1o, 1, N1i, N2i, (1), 1] + + auto m_tilde = + max(S, {-3}, /*keep_dim=*/true); // [N1o, N2o, d1o, 1, N1i, (1), 1, 1] + + auto* neg_infty = IrBuilder::create( + -std::numeric_limits::infinity(), DataType::Double); + ScanResult max_scan_result = scan( + m_tilde, + 2, + BinaryOpType::Max, + /*init=*/neg_infty, + /*discount_factor=*/nullptr, + /*return_exclusive=*/true); + TensorView* m = + max_scan_result.inclusive; // [N1o, N2o, (d1o), 1, N1i, 1, 1, 1] + TensorView* m_prev = max_scan_result.exclusive; + + auto P_tilde = exp(sub(S, m_tilde)); // [N1o, N2o, d1o, 1, N1i, N2i, 1, 1] + + auto l_tilde = + sum(P_tilde, + {-3}, + /*keep_dim=*/true); // [N1o, N2o, d1o, 1, N1i, (1), 1, 1] + + auto first_discount = exp(sub(m_prev, m)); // [N1o, N2o, d1o, 1, N1i, 1, 1, 1] + + auto l_tilde_factor = + exp(sub(m_tilde, m)); // [N1o, N2o, d1o, 1, N1i, 1, 1, d2i] + auto next_l = + mul(l_tilde_factor, l_tilde); // [N1o, N2o, d1o, 1, N1i, 1, 1, d2i] + + ScanResult sum_scan_result = scan( + next_l, + 2, + BinaryOpType::Add, + /*init=*/fusion->zeroVal(DataType::Float), + /*discount_factor=*/first_discount, + /*return_exclusive=*/true); + TensorView* l = + sum_scan_result.inclusive; // [N1o, N2o, (d1o), 1, N1i, 1, 1, 1] + TensorView* l_prev = sum_scan_result.exclusive; + + auto O_discount = + mul(div(l_prev, l), first_discount); // [N1o, N2o, d1o, 1, N1i, 1, 1, 1] + + // P_tilde = [N1o, N2o, d1o, 1, N1i, N2i, 1, d2i] + // V = [1, N2o, 1, d2o, 1, N2i, 1, d2i] + auto PtildeV = + sum(mul(P_tilde, V), + /*dims=*/{-3}, + /*keep_dim=*/true); // [N1o, N2o, d1o, d2o, N1i, (1), 1, d2i] + + auto O = prefixSum( + mul(div(l_tilde_factor, l), PtildeV), + 2, + /*discount_factor=*/O_discount); // [N1o, N2o, (d1o), d20, N1i, 1, 1, d2i] + + auto O_final = reductionOp( + BinaryOpType::RHS, + {1, 2}, + /*init=*/fusion->zeroVal(DataType::Float), + set(O), // TODO: is this set really needed to avoid computeWith error on + // O? 
+ /*keepdim=*/true); // [N1o, (1), (1), d2o, N1i, 1, 1, d2i] + + fusion->addOutput(O_final); + + fusion->printMath(); + + // We don't inline the scans past the scan dimension + std::unordered_set uninlineable_ids; + for (Expr* expr : fusion->exprs()) { + if (expr->isA()) { + for (auto tv : ir_utils::filterByType(expr->outputs())) { + for (IterDomain* id : tv->getLoopDomain()) { + uninlineable_ids.insert(id); + } + } + } + } + + Q->cacheAfter(); + K->cacheAfter(); + V->cacheAfter(); + O_final->cacheBefore(); + + inlineMost(uninlineable_ids); + + // These TVs are not inlined, but instead we set computeWith on them + for (Expr* expr : fusion->exprs()) { + if (expr->isA()) { + expr->output(0)->as()->computeWith(-1); + for (Val* v : expr->inputs()) { + // By using `uninlineable_ids` above, we prevent producers of scan from + // inlining with the ScanOp past the scan dim, even though this is + // desired. Here we do this inlining manually instead. + v->as()->inlineAt(-1); + } + } + } + + O_final->axis(0)->parallelize(ParallelType::BIDx); + O_final->axis(3)->parallelize(ParallelType::BIDy); + scheduler_utils::parallelizeAllLike(O_final); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor q = at::randn({N1o, 1, d1o, 1, N1i, 1, d1i, 1}, options); + at::Tensor k = at::randn({1, N2o, d1o, 1, 1, N2i, d1i, 1}, options); + at::Tensor v = at::randn({1, N2o, 1, d2o, 1, N2i, 1, d2i}, options); + std::vector inputs{q, k, v}; + + auto qorig = q.transpose(2, 4).reshape({N1o * N1i, d1o * d1i}); // 6, 42 + auto korig = k.transpose(2, 5).reshape({N2o * N2i, d1o * d1i}); // 20, 42 + auto vorig = v.transpose(3, 5).reshape({N2o * N2i, d2o * d2i}); // 20, 72 + auto qktref = at::matmul(qorig, korig.t()); // 6, 20 + auto sref = at::softmax(qktref, 1); // 6, 20 + auto ref = + at::matmul(sref, vorig) + .reshape({N1o, 1, 1, d2o, N1i, 1, 1, d2i}); // 2, 1, 1, 8, 3, 1, 1, 9 + + KernelExecutor ke; + ke.compile(fusion.get(), inputs); + + auto cg_outputs = ke.run(inputs); + + EXPECT_TRUE( + at::allclose(cg_outputs[0].as().squeeze(), ref.squeeze())); + //<< " returned " << cg_outputs[0].as()[0] << " but expected " + //<< ref[0]; +} + } // namespace nvfuser
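Editor's note (not part of the patch): the recurrence documented in `csrc/ops/arith.h` for `scan`/`prefixSum` with a discount factor is easy to misread, so here is a minimal host-side reference sketch of the intended semantics, including the exclusive output saved before the discount is applied (as in the `IndexLowering::handle(const ScanOp*)` change above). The names `ScanRefResult` and `discountedScanRef` are illustrative only and are not nvFuser API.

```cpp
// Reference for the discounted scan recurrence (illustrative sketch only):
//
//   exclusive[i] = running value *before* x[i] is accumulated (init at i == 0)
//   inclusive[i] = f[i] * inclusive[i-1] + x[i],   inclusive[-1] == init
//
// With init == 0 (the Add identity) this matches
// prefixSum(tv, dim, discount_factor); the first discount factor only ever
// multiplies init, which is why the docs say it is effectively ignored.
#include <cstddef>
#include <vector>

struct ScanRefResult {
  std::vector<float> inclusive;
  std::vector<float> exclusive;
};

ScanRefResult discountedScanRef(
    const std::vector<float>& x,
    const std::vector<float>& f, // assumed same length as x
    float init) {
  ScanRefResult r;
  r.inclusive.resize(x.size());
  r.exclusive.resize(x.size());
  float running = init;
  for (size_t i = 0; i < x.size(); ++i) {
    r.exclusive[i] = running;        // saved before discounting, as in lowering
    running = f[i] * running + x[i]; // discount previous value, then accumulate
    r.inclusive[i] = running;
  }
  return r;
}
```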
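Editor's note (not part of the patch): the `ScanTest.OnlineSoftmax` and `FlashAttentionNoMma` tests both rely on rewriting the online-normalizer recurrence of Milakov & Gimelshein (https://arxiv.org/abs/1805.02867) as a max-scan plus a discounted prefix sum. A compact host-side sketch of that mapping is below; `softmaxDenominatorOnline` is a hypothetical helper, not code from this PR.

```cpp
// Online softmax normalizer as a discounted scan (illustrative sketch only):
//
//   m[j] = max(m[j-1], x[j])                                  (max scan)
//   d[j] = d[j-1] * exp(m[j-1] - m[j]) + exp(x[j] - m[j])     (discounted sum)
//
// i.e. d is a prefix sum of exp(x[j] - m[j]) with per-element discount factor
// exp(m[j-1] - m[j]); the softmax denominator is d[N-1], already expressed in
// terms of the global max m[N-1].
#include <cmath>
#include <limits>
#include <vector>

float softmaxDenominatorOnline(const std::vector<float>& x) {
  float m = -std::numeric_limits<float>::infinity(); // running max, "m[-1]"
  float d = 0.0f;                                    // running denom, "d[-1]"
  for (float xi : x) {
    const float m_prev = m;
    m = std::max(m, xi);
    // exp(m_prev - m) plays the role of the discount factor; on the first
    // iteration it is exp(-inf) == 0, so the init value drops out exactly as
    // with prefixSum's zero init.
    d = d * std::exp(m_prev - m) + std::exp(xi - m);
  }
  return d; // equals sum_j exp(x[j] - max(x))
}
```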