
Conversation

@jacobhinkle (Collaborator) commented Apr 8, 2025

ScanTest.OnlineSoftmax currently passes and generates this code (as of db5798c):

// Codegen generated code
__global__ void nvfuser_none_f0_c0_r0_g0(Tensor<float, 1, 1> T0, Tensor<float, 0, 0> T8) {
  NVFUSER_DEFINE_MAGIC_ZERO;
  Array<float, 1, 1> T1;
  Array<float, 1, 1> T2;
  Array<float, 1, 1> T7;
  Array<float, 1, 1> T10;
  T10[0] = 0.000000000e+00f;
  #pragma unroll
  for(nvfuser_index_t i0 = 0; i0 < 32; ++i0) {
    Array<float, 1, 1> T9;
    T9[0] = 0;
    T9[0]
       = T0[(T0.alloc_stride[0LL] * (i0 + nvfuser_zero))];
    T2[0] = ((i0 > 0) ? (T1[0]) : NEG_INFINITY);
    T1[0] = fmax(
      T2[0],
      (T9[0]));
    Array<float, 1, 1> T3;
    T3[0]
      = T9[0]
      - T1[0];
    Array<float, 1, 1> T5;
    T5[0]
      = T2[0]
      - T1[0];
    Array<float, 1, 1> T6;
    T6[0]
       = expf(T5[0]);
    Array<float, 1, 1> T4;
    T4[0]
       = expf(T3[0]);
    T7[0]
      = (T6[0] * ((i0 > 0) ? (T7[0]) : 0.000000000e+00f))
      + (T4[0]);
    T10[0] = rhs(
      T10[0],
      T7[0]);
  }
  NVFUSER_UPDATE_MAGIC_ZERO;
  T8[0]
     = T10[0];
}

This is the softmax denominator only.
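
For reference, here is a minimal host-side sketch (plain C++, not from the PR) of the recurrence the generated loop implements; the variables map onto the kernel's T1 (running max), T2 (previous max), and T7 (running denominator):

#include <algorithm>
#include <cmath>
#include <limits>
#include <vector>

// Online softmax denominator: maintain the running max m and denominator d,
// rescaling d whenever the max increases:
//   d_i = exp(m_{i-1} - m_i) * d_{i-1} + exp(x_i - m_i)
float onlineSoftmaxDenominator(const std::vector<float>& x) {
  float m = -std::numeric_limits<float>::infinity(); // running max (T1)
  float d = 0.0f;                                    // running denominator (T7)
  for (float xi : x) {
    float m_prev = m;         // previous max (T2)
    m = std::max(m_prev, xi); // T1[0] = fmax(T2[0], T9[0])
    d = std::exp(m_prev - m) * d + std::exp(xi - m);
  }
  return d;
}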

Next I will:

  • Fix inlining of intermediates as mentioned above
  • Add validation
  • Add more tests for corner cases
  • Add a mock FlashAttention test

github-actions bot commented Apr 8, 2025

Review updated until commit 302015c

Description

  • Introduced ScanOp for scan operations with support for exclusive and reduction outputs

  • Added validation and allocation logic for ScanOp

  • Updated tests to include various scan cases, including online softmax and flash attention

  • Enhanced TensorDomain and IterDomain to support Scan type


Changes walkthrough 📝

Relevant files

Enhancement (20 files)
  • lower2device.cpp: Add validation for ScanOp (+3/-0)
  • allocation.cpp: Update allocation logic for ScanOp (+16/-2)
  • index.cpp: Implement indexing for ScanOp (+82/-0)
  • utils.cpp: Add ScanOp to TV operation check (+1/-0)
  • validation.cpp: Add validation for ScanOp scheduling (+72/-0)
  • indexing.cpp: Update override index handling in TensorIndexer (+1/-1)
  • index_compute.cpp: Add override_index_ids to getConsumerIndex (+6/-1)
  • nodes.cpp: Enhance ScanOp to support exclusive and reduction outputs (+176/-26)
  • logical_domain_map.cpp: Handle ScanOp in ComputeAtLogicalDomainMapBuilder (+1/-1)
  • arith.cpp: Update scan and prefixSum to support discount factor and additional outputs (+65/-21)
  • utils.cpp: Update promoteIterType and newOutputIterDomain for Scan (+7/-4)
  • type.cpp: Add string representations for LHS, RHS, and Scan IterType (+6/-0)
  • helpers.cu: Add lhs and rhs device functions (+10/-0)
  • index.h: Add handle for ScanOp in IndexLowering (+1/-0)
  • validation.h: Declare validateScans function (+4/-0)
  • internal_base_nodes.h: Add isScan method to IterDomain (+5/-0)
  • internal_nodes.h: Enhance ScanOp to support exclusive and reduction outputs (+43/-10)
  • logical_domain_map.h: Add handle for ScanOp in ComputeAtLogicalDomainMapBuilder (+4/-0)
  • arith.h: Update scan and prefixSum to support discount factor and additional outputs (+34/-8)
  • type.h: Add LHS, RHS, and Scan to BinaryOpType and IterType (+7/-2)

Bug fix (1 file)
  • expr_evaluator.cpp: Ensure output size matches definition outputs (+1/-0)

Documentation (1 file)
  • test_compute_with.cpp: Add printMath calls for debugging (+3/-0)

Tests (1 file)
  • test_scan.cpp: Add comprehensive tests for ScanOp including online softmax and flash attention (+530/-10)

Miscellaneous (1 file)
  • dispatch.h: Reorder ScanOp in dispatch macro (+1/-1)

Configuration changes (1 file)
  • CMakeLists.txt: Add test_scan.cpp to JIT_TEST_SRCS (+1/-0)

PR Reviewer Guide 🔍

Here are some key observations to aid the review process:

🧪 PR contains tests
⚡ Recommended focus areas for review

Allocation Logic

The logic for allocating IterType::Scan IDs seems correct, but it would be good to ensure that this does not lead to unnecessary memory allocations or performance degradation.

if (id->isScan()) {
  // Allocate IterType::Scan IDs only if they are outputs or not computeWith.
  // We know tv must not have computeAt past the scan id, so without
  // computeWith, we know the expression won't be inlined and we'll need to
  // allocate the scan id.
  if (tv->isFusionOutput()) {
    return true;
  }
  return !tv->hasComputeWith();
}
Validation Logic

The validation logic for scans is comprehensive, but it would be beneficial to add more test cases to ensure all edge cases are covered, especially regarding the interaction with other operations and scheduling.

// When we compute a scan we compute something like the following
//
//   Array<float, 1, 1> T3;
//   #pragma unroll
//   for(nvfuser_index_t i1 = 0; i1 < 32; ++i1) {
//     T3[0]
//       = ((i1 > 0) ? (T3[0]) : 0.000000000e+00f)
//       + (T2[i1]);
//   }
//
// We need to validate that T3 is not inlined past the scan axis (the i1 loop
// in this case), because its lifetime must persist beyond the scan loop. Note
// that it is permissible to use `computeWith` as in this example to move the
// computed position inside the scan loop, alleviating the need to allocate an
// axis of size 32 in this case.
//
//
// Validate:
// 1. Outputs are inlined with all their uses past the scan dim
// 2. Discount factor and input are computed with this expression past the
//   scan dim
// 3. Outputs are not inlined past the scan dimension, as we require the
//   scan outputs to be allocated outside the scan loop
void validateScans(Fusion* fusion) {
  for (auto sop : ir_utils::getOpsOfType<ScanOp>(fusion)) {
    auto* out = sop->out()->as<TensorView>();
    TensorView* out_exclusive = sop->outExclusive();

    // Find position of scan dim in loop domain
    IterDomain* scan_id = out->getLogicalDomain().at((size_t)sop->scanDim());
    int64_t scan_pos = -1L;
    for (int64_t pos : arange(out->nDims())) {
      if (out->axis(pos) == scan_id) {
        scan_pos = pos;
        break;
      }
    }
    NVF_ERROR(
        scan_pos != -1L,
        "Could not find scan dimension ",
        scan_id->toString(),
        " in loop domain. Scan dimensions must not be scheduled with splits or "
        "merges");

    const auto check_uses = [&](TensorView* output) {
      for (Expr* use : output->uses()) {
        for (Val* outp : use->outputs()) {
          if (auto* out_tv = dynamic_cast<TensorView*>(outp)) {
            NVF_ERROR(
                out_tv->getComputeWithPosition() >= scan_pos,
                "Use of output, ",
                use->toString(),
                " must have all outputs inlined or computeWith to or past scan "
                "position ",
                scan_pos);
          }
        }
      }
    };

    check_uses(out);

    if (out_exclusive != nullptr) {
      NVF_ERROR(out_exclusive->getComputeWithPosition() >= scan_pos);
      NVF_ERROR(out_exclusive->getComputeAtPosition() <= scan_pos);
      check_uses(out_exclusive);
    }
    // Must have allocated outside scan loop
    NVF_ERROR(out->getComputeAtPosition() <= scan_pos);
  }
}

} // namespace nvfuser
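
For intuition, a hypothetical schedule that check (3) would reject (a sketch, not from the PR; assumes the test helper makeContigTensor and that scan's trailing arguments are defaulted as declared in arith.h):

Fusion fusion;
FusionGuard fg(&fusion);
TensorView* tv0 = makeContigTensor(1);
fusion.addInput(tv0);
// Inclusive scan along the only axis; the scan position is 0.
TensorView* tv1 = scan(tv0, /*dim=*/0, BinaryOpType::Add).inclusive;
TensorView* tv2 = neg(tv1);
fusion.addOutput(tv2);
// Inlining the scan output into its consumer past the scan axis would place
// tv1's allocation inside the scan loop, so validateScans should reject it.
tv1->inlineAt(1);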
Scan Implementation

The implementation of the scan operation seems correct, but it would be good to ensure that the handling of discount factors and different binary operations is robust and handles all possible input types and shapes.

ScanResult scan(
    TensorView* tv,
    int64_t dim,
    BinaryOpType op_type,
    Val* init,
    Val* discount_factor,
    bool return_exclusive,
    bool return_reduction) {
  const std::vector<IterDomain*> logical_dom =
      TensorDomain::noReductions(tv->getLogicalDomain());

  dim = wrapDim(dim, (int64_t)logical_dom.size());

  IterDomain* scan_id = logical_dom.at((size_t)dim);

  if (init == nullptr) {
    init = ops::binOpIdentity(op_type, tv->dtype());
    NVF_ERROR(init != nullptr);
  }

  // Special case: scanning along a broadcast dimension is a no-op.
  // Assumes init is the identity for op_type
  if (scan_id->isBroadcast()) {
    if (scan_id->hasExpandedExtent()) {
      NVF_THROW(
          "Closed-form scan of expanded dimension is not yet implemented");
    }
    // Exclusive scan is just the init val
    return {set(tv), mul(init, ones_like(tv))};
  }

  std::vector<IterDomain*> new_dom;
  DataType dtype = tv->dtype();
  if (discount_factor == nullptr) {
    new_dom = ops::newOutputDomain({tv});
  } else {
    new_dom = ops::newOutputDomain({tv, discount_factor});
    dtype = promoteType(tv->dtype(), discount_factor->dtype());
    tv = maybeCastOp(dtype, tv);
    discount_factor = maybeCastOp(dtype, discount_factor);
  }
  new_dom.at((size_t)dim) = IterDomainBuilder(new_dom.at((size_t)dim))
                                .iter_type(IterType::Scan)
                                .build();
  auto* td = IrBuilder::create<TensorDomain>(
      new_dom, TensorDomain::getContiguityFilledWith(new_dom, true));

  ScanResult result;

  result.inclusive = IrBuilder::create<TensorView>(td, tv->dtype());

  if (return_exclusive) {
    result.exclusive = ops::newOutputTV({result.inclusive}, tv->dtype());
  }

  if (return_reduction) {
    std::vector<IterDomain*> red_dom = ops::newOutputDomain({tv});
    red_dom.at((size_t)dim) = IterDomainBuilder(red_dom.at((size_t)dim))
                                  .iter_type(IterType::Reduction)
                                  .build();
    auto* red_td = IrBuilder::create<TensorDomain>(
        red_dom, TensorDomain::getContiguityFilledWith(red_dom, true));

    result.reduction = IrBuilder::create<TensorView>(red_td, tv->dtype());
  }

  IrBuilder::createInContainer<ScanOp>(
      tv->container(),
      op_type,
      result.inclusive,
      result.exclusive,
      result.reduction,
      tv,
      discount_factor,
      init,
      dim);

  return result;
}

TensorView* prefixSum(TensorView* tv, int64_t dim, Val* discount_factor) {
  return scan(
             tv,
             dim,
             BinaryOpType::Add,
             /*init=*/tv->fusion()->zeroVal(tv->dtype()),
             discount_factor)
      .inclusive;
}

} // namespace nvfuser
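
A minimal usage sketch of the interface above (assumes the test helper makeContigTensor and that discount_factor defaults to nullptr):

Fusion fusion;
FusionGuard fg(&fusion);
TensorView* tv0 = makeContigTensor(1);
fusion.addInput(tv0);
// Inclusive running sum along dim 0: out[i] = in[0] + ... + in[i].
// With a non-null discount_factor d, the carried value is rescaled each
// step: out[i] = d[i] * out[i-1] + in[i], as in the online-softmax kernels.
TensorView* tv1 = prefixSum(tv0, /*dim=*/0);
fusion.addOutput(tv1);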

@jacobhinkle changed the title from WIP: PrefixSumOp to WIP: ScanOp on Apr 10, 2025
  auto index_info = computeIndex(expr, indexed_ids, for_loops);
  for (const auto& [indexed_id, index] : override_index) {
-   index_info.index_map.emplace(traversalGraph().toGroup(indexed_id), index);
+   index_info.index_map[traversalGraph().toGroup(indexed_id)] = index;
@jacobhinkle (author) commented:

I believe this is a bug since emplace will simply fail and return false if there is a pre-existing entry, whereas we need to update the index here instead.
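
A standalone illustration of the difference (plain C++, not from the PR):

#include <cassert>
#include <unordered_map>

int main() {
  std::unordered_map<int, int> m{{1, 10}};
  // emplace does nothing and reports failure when the key already exists...
  auto [it, inserted] = m.emplace(1, 20);
  assert(!inserted && m.at(1) == 10);
  // ...whereas operator[] assignment always updates the mapped value.
  m[1] = 20;
  assert(m.at(1) == 20);
}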

Comment on lines +2330 to +2336
+ std::unordered_map<IterDomain*, Val*> override_index_ids;
+ for (auto& [pos, idx] : override_index) {
+   override_index_ids.emplace(
+       consumer->getMaybeAllocationDomain().at((size_t)pos), idx);
+ }
  index = GpuLower::current()->tensorIndexer().getLinearIndex(
-     consumer, consumer->definition(), loops);
+     consumer, consumer->definition(), loops, override_index_ids);
@jacobhinkle (author) commented:

I believe this is necessary for getConsumerIndex to respect override_index.

Comment on lines 193 to 209
  // We don't inline the scans past the scan dimension
  std::unordered_set<IterDomain*> uninlineable_ids;
  for (TensorView* tv : {m, m_prev, denoms}) {
    for (IterDomain* id : tv->getLoopDomain()) {
      uninlineable_ids.insert(id);
    }
  }

  inlineMost(uninlineable_ids);

  // These TVs are not inlined, but instead we set computeWith on them
  for (TensorView* tv : {m, m_prev, denoms}) {
    tv->computeWith(-1);
    for (Val* v : tv->definition()->inputs()) {
      v->as<TensorView>()->computeWith(-1);
    }
  }
@jacobhinkle commented Apr 18, 2025:

Scheduling inlining with scan

The above schedule results in this fusion IR:

Inputs:
  T0_g_float[iS0{32}]
Outputs:
  T8_g_float[]

%kernel {
T9_l_float[iS11{32}] compute_with( 1 )
   = Set( T0_g_float[iS0{32}], cache_op=Streaming )
T1_l_float[cS2{32}] compute_with( 1 ) maybe_produce_pos( 1 ),
T2_l_float[cS3{32}] compute_with( 1 ) maybe_produce_pos( 1 )
   = scan(fmax,
          T9_l_float[iS11{32}] compute_with( 1 ),
          dim=0,
          init=double(-inf))
T3_l_float[iS4{32}] ca_pos( 1 ) maybe_produce_pos( 1 )
   = T9_l_float[iS11{32}] compute_with( 1 )
   - T1_l_float[cS2{32}] compute_with( 1 ) maybe_produce_pos( 1 );
T4_l_float[iS5{32}] compute_with( 1 ) produce_pos( 1 )
   = expf(T3_l_float[iS4{32}] ca_pos( 1 ) maybe_produce_pos( 1 ));
T5_l_float[iS6{32}] ca_pos( 1 ) maybe_produce_pos( 1 )
   = T2_l_float[cS3{32}] compute_with( 1 ) maybe_produce_pos( 1 )
   - T1_l_float[cS2{32}] compute_with( 1 ) maybe_produce_pos( 1 );
T6_l_float[iS7{32}] compute_with( 1 ) produce_pos( 1 )
   = expf(T5_l_float[iS6{32}] ca_pos( 1 ) maybe_produce_pos( 1 ));
T7_l_float[cS9{32}] compute_with( 1 ) maybe_produce_pos( 1 )
   = scan(add,
          T4_l_float[iS5{32}] compute_with( 1 ) produce_pos( 1 ),
          dim=0,
          discount_factor=T6_l_float[iS7{32}] compute_with( 1 ) produce_pos( 1 ),
          init=float(0))
T10_l_float[rS10{32}] maybe_produce_pos( 1 )
   = reduction( T7_l_float[cS9{32}] compute_with( 1 ) maybe_produce_pos( 1 ), op = rhs, initial value = float(0), allreduce = false )
T8_g_float[]
   = Set( T10_l_float[rS10{32}] maybe_produce_pos( 1 ), cache_op=Streaming )

and the following kernel:

// Codegen generated code
__global__ void nvfuser_none_f0_c0_r0_g0(Tensor<float, 1, 1> T0, Tensor<float, 0, 0> T8) {
  NVFUSER_DEFINE_MAGIC_ZERO;
  Array<float, 32, 1> T9;
  #pragma unroll
  for(nvfuser_index_t i0 = 0; i0 < 32; ++i0) {
    T9[i0] = 0;
  }
  NVFUSER_UPDATE_MAGIC_ZERO;
  Array<float, 1, 1> T1;
  Array<float, 1, 1> T2;
  Array<float, 32, 1> T6;
  Array<float, 32, 1> T4;
  Array<float, 1, 1> T7;
  Array<float, 1, 1> T10;
  T10[0] = 0.000000000e+00f;
  #pragma unroll
  for(nvfuser_index_t i0 = 0; i0 < 32; ++i0) {
    T9[i0]
       = T0[(T0.alloc_stride[0LL] * (i0 + nvfuser_zero))];
    T2[0] = ((i0 > 0) ? (T1[0]) : NEG_INFINITY);
    T1[0] = fmax(
      ((i0 > 0) ? (T1[0]) : NEG_INFINITY),
      (T9[i0]));
    Array<float, 1, 1> T5;
    T5[0]
      = T2[0]
      - T1[0];
    T6[i0]
       = expf(T5[0]);
    Array<float, 1, 1> T3;
    T3[0]
      = T9[i0]
      - T1[0];
    T4[i0]
       = expf(T3[0]);
    T7[0]
      = (T6[i0] * ((i0 > 0) ? (T7[0]) : 0.000000000e+00f))
      + (T4[i0]);
    T10[0] = rhs(
      T10[0],
      T7[0]);
  }
  NVFUSER_UPDATE_MAGIC_ZERO;
  T8[0]
     = T10[0];
}

I don't yet understand how I can get T9, T4, and T6 to be allocated with size 1 instead of 32. The ID model maps all IDs together in the exact map but not in the loop map (before resolveComputeWith):

IterDomainGraphs {
  IdGraph exact{
  Disjoint Ids:
    (idgs){
      idg{0 2 3 4 5 6 7 9 10 11}
}
  Disjoint Expression groups:
    (exprgs){
    }
   } IdGraph

  IdGraph almost_exact{
  Disjoint Ids:
    (idgs){
      idg{0 2 3 4 5 6 7 9 10 11}
}
  Disjoint Expression groups:
    (exprgs){
    }
   } IdGraph

  IdGraph broadcast{
  Disjoint Ids:
    (idgs){
      idg{0 2 3 4 5 6 7 9 10 11}
}
  Disjoint Expression groups:
    (exprgs){
    }
   } IdGraph

  IdGraph permissive{
  Disjoint Ids:
    (idgs){
      idg{0 2 3 4 5 6 7 9 10 11}
}
  Disjoint Expression groups:
    (exprgs){
    }
   } IdGraph

  IdGraph loop{
  Disjoint Ids:
    (idgs){
      idg{0}
      idg{2 3}
      idg{4 5}
      idg{6 7}
      idg{9}
      idg{10}
      idg{11}
}
  Disjoint Expression groups:
    (exprgs){
    }
   } IdGraph

 } IterDomainGraphs

I am wondering:

  • why the loop graph is finer than exact: this is just because ca_pos is 0 for a lot of these tensors due to uninlineable_ids.
  • how I can control inlining such that producers of a scan are inlined past the scan dimension, so that I don't need to use computeWith on them. I think uninlineable_ids prevents inlining with both consumers and producers, but I'd like to still inline the producers with it.
  • whether there is a better way to force allocation at a higher scope but still inline the expression
  • what exactly "storeAt" refers to in the code comments

A collaborator replied:

> I don't yet understand how I can get T9, T4, and T6 to be allocated with size 1 instead of 32.

Because their computeAt positions are all 0, they are allocated entirely. The computeWith position only changes the loop where the expression appears.

> why the loop graph is finer than exact

Loop only maps when inlined, so by definition, it should be finer.

> how can I control inlining such that producers of a scan are inlined past the scan dimension so that I don't need to use computeWith on them.

I'm not sure, as I don't know how the scan is implemented, but if it's like reduction, don't we need to update MaxPosInliner?

> Whether there is another way to force allocation at a higher scope but still inline the expression

This is exactly what computeWith is supposed to address.

> What exactly is "storeAt" referred to in the code comments

Probably the same concept as store_at in Halide.
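
To make the allocation rule above concrete, here is a schematic distilled from the two generated kernels in this thread, showing how the computeAt position determines T9's allocation size:

// computeAt position 0: allocated outside the loop at full extent.
Array<float, 32, 1> T9;
for (nvfuser_index_t i = 0; i < 32; ++i) {
  T9[i] = /* ... */;
}

// computeAt position 1: a single element allocated inside the loop body.
for (nvfuser_index_t i = 0; i < 32; ++i) {
  Array<float, 1, 1> T9;
  T9[0] = /* ... */;
}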

// This test runs. I'm not sure if it's correct yet, but there is one main
// loop so that's nice
//
// Dao et al. 2022. FlashAttention: Fast and Memory-Efficient Exact Attention
// with IO-Awareness. https://arxiv.org/abs/2205.14135
TEST_F(ScanTest, FlashAttentionNoMma) {
@jacobhinkle commented Apr 22, 2025:
As of deaff1f this generates the following code:

// Codegen generated code
__global__ void nvfuser_none_f0_c0_r0_g0(Tensor<float, 8, 8> T0, Tensor<float, 8, 8> T1, Tensor<float, 8, 8> T2, Tensor<float, 8, 8> T31) {
  NVFUSER_DEFINE_MAGIC_ZERO;
  nvfuser_index_t i0;
  i0 = T0.alloc_stride[0LL] * ((nvfuser_index_t)blockIdx.x);
  nvfuser_index_t i1;
  i1 = T2.alloc_stride[3LL] * ((nvfuser_index_t)blockIdx.y);
  nvfuser_index_t i2;
  i2 = (27 * ((nvfuser_index_t)blockIdx.y)) + (216 * ((nvfuser_index_t)blockIdx.x));
  Array<float, 12, 1> T8;
  Array<float, 12, 1> T9;
  Array<float, 12, 1> T19;
  Array<float, 12, 1> T20;
  Array<float, 108, 1> T28;
  Array<float, 27, 1> T30;
  #pragma unroll
  for(nvfuser_index_t i3 = 0; i3 < 3; ++i3) {
    nvfuser_index_t i4;
    i4 = 9 * i3;
    #pragma unroll
    for(nvfuser_index_t i5 = 0; i5 < 9; ++i5) {
      T30[(i4 + i5)] = 0.000000000e+00f;
    }
  }
  NVFUSER_UPDATE_MAGIC_ZERO;
  #pragma unroll
  for(nvfuser_index_t i6 = 0; i6 < 4; ++i6) {
    nvfuser_index_t i7;
    i7 = T1.alloc_stride[1LL] * i6;
    nvfuser_index_t i8;
    i8 = 3 * i6;
    nvfuser_index_t i9;
    i9 = i1 + (T2.alloc_stride[1LL] * i6);
    nvfuser_index_t i10;
    i10 = 27 * i6;
    #pragma unroll
    for(nvfuser_index_t i11 = 0; i11 < 6; ++i11) {
      nvfuser_index_t i12;
      i12 = i7 + (T1.alloc_stride[2LL] * i11);
      nvfuser_index_t i13;
      i13 = i0 + (T0.alloc_stride[2LL] * i11);
      #pragma unroll
      for(nvfuser_index_t i3 = 0; i3 < 3; ++i3) {
        nvfuser_index_t i14;
        i14 = i13 + (T0.alloc_stride[4LL] * i3);
        nvfuser_index_t i15;
        i15 = i8 + i3;
        nvfuser_index_t i16;
        i16 = 9 * i3;
        nvfuser_index_t i17;
        i17 = i10 + i16;
        Array<float, 5, 1> T5;
        #pragma unroll
        for(nvfuser_index_t i18 = 0; i18 < 5; ++i18) {
          nvfuser_index_t i19;
          i19 = i12 + (T1.alloc_stride[5LL] * i18);
          Array<float, 1, 1> T4;
          T4[0] = 0.000000000e+00f;
          #pragma unroll
          for(nvfuser_index_t i20 = 0; i20 < 7; ++i20) {
            nvfuser_index_t i21;
            i21 = i20 + nvfuser_zero;
            Array<float, 1, 1> T33;
            T33[0] = 0;
            T33[0]
               = T1[(i19 + (T1.alloc_stride[6LL] * i21))];
            Array<float, 1, 1> T32;
            T32[0] = 0;
            T32[0]
               = T0[(i14 + (T0.alloc_stride[6LL] * i21))];
            Array<float, 1, 1> T3;
            T3[0]
              = T32[0]
              * T33[0];
            T4[0]
              = T4[0]
              + T3[0];
          }
          T5[i18]
             = T4[0];
        }
        Array<float, 1, 1> T6;
        T6[0] = NEG_INFINITY;
        #pragma unroll
        for(nvfuser_index_t i22 = 0; i22 < 5; ++i22) {
          T6[0] = fmax(
            T6[0],
            T5[i22]);
        }
        Array<float, 1, 1> T7;
        T7[0]
           = T6[0];
        T9[i15] = ((i11 > 0) ? (T8[i15]) : NEG_INFINITY);
        T8[i15] = fmax(
          T9[i15],
          (T7[0]));
        Array<float, 1, 1> T14;
        T14[0]
          = T9[i15]
          - T8[i15];
        Array<float, 1, 1> T15;
        T15[0]
           = expf(T14[0]);
        Array<float, 1, 1> T12;
        T12[0] = 0.000000000e+00f;
        Array<float, 9, 1> T24;
        #pragma unroll
        for(nvfuser_index_t i23 = 0; i23 < 9; ++i23) {
          T24[i23] = 0.000000000e+00f;
        }
        #pragma unroll
        for(nvfuser_index_t i24 = 0; i24 < 5; ++i24) {
          nvfuser_index_t i25;
          i25 = i9 + (T2.alloc_stride[5LL] * i24);
          Array<float, 1, 1> T10;
          T10[0]
            = T5[i24]
            - T7[0];
          Array<float, 1, 1> T11;
          T11[0]
             = expf(T10[0]);
          T12[0]
            = T12[0]
            + T11[0];
          #pragma unroll
          for(nvfuser_index_t i23 = 0; i23 < 9; ++i23) {
            Array<float, 1, 1> T34;
            T34[0] = 0;
            T34[0]
               = T2[(i25 + (T2.alloc_stride[7LL] * (i23 + nvfuser_zero)))];
            Array<float, 1, 1> T23;
            T23[0]
              = T11[0]
              * T34[0];
            T24[i23]
              = T24[i23]
              + T23[0];
          }
        }
        Array<float, 1, 1> T13;
        T13[0]
           = T12[0];
        Array<float, 1, 1> T16;
        T16[0]
          = T7[0]
          - T8[i15];
        Array<float, 1, 1> T17;
        T17[0]
           = expf(T16[0]);
        Array<float, 1, 1> T18;
        T18[0]
          = T17[0]
          * T13[0];
        T20[i15] = ((i11 > 0) ? (T19[i15]) : 0.000000000e+00f);
        T19[i15]
          = (T15[0] * T20[i15])
          + (T18[0]);
        Array<float, 1, 1> T21;
        T21[0]
          = T20[i15]
          / T19[i15];
        Array<float, 1, 1> T22;
        T22[0]
          = T21[0]
          * T15[0];
        Array<float, 1, 1> T26;
        T26[0]
          = T17[0]
          / T19[i15];
        #pragma unroll
        for(nvfuser_index_t i5 = 0; i5 < 9; ++i5) {
          nvfuser_index_t i26;
          i26 = i17 + i5;
          Array<float, 1, 1> T25;
          T25[0]
             = T24[i5];
          Array<float, 1, 1> T27;
          T27[0]
            = T26[0]
            * T25[0];
          T28[i26]
            = (T22[0] * ((i11 > 0) ? (T28[i26]) : 0.000000000e+00f))
            + (T27[0]);
          Array<float, 1, 1> T29;
          T29[0]
             = T28[i26];
          T30[(i16 + i5)] = rhs(
            T30[(i16 + i5)],
            T29[0]);
        }
      }
    }
  }
  NVFUSER_UPDATE_MAGIC_ZERO;
  #pragma unroll
  for(nvfuser_index_t i27 = 0; i27 < 3; ++i27) {
    nvfuser_index_t i28;
    i28 = 9 * i27;
    nvfuser_index_t i29;
    i29 = i2 + i28;
    #pragma unroll
    for(nvfuser_index_t i30 = 0; i30 < 9; ++i30) {
      Array<float, 1, 1> T35;
      T35[0]
         = T30[(i28 + i30)];
      T31[(i29 + (i30 + nvfuser_zero))]
         = T35[0];
    }
  }
  NVFUSER_UPDATE_MAGIC_ZERO;
}

The goal here is to approximate the original FlashAttention pattern:
[figure: FlashAttention algorithm from Dao et al. 2022]
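
For reference, the per-tile update from the paper, with running max m, running denominator ℓ, and output accumulator O (the tilde quantities are the current tile's local max, local sum, and unnormalized exponentials):

$$
\begin{aligned}
m^{\mathrm{new}} &= \max(m, \tilde{m}) \\
\ell^{\mathrm{new}} &= e^{m - m^{\mathrm{new}}}\,\ell + e^{\tilde{m} - m^{\mathrm{new}}}\,\tilde{\ell} \\
O^{\mathrm{new}} &= \frac{\ell\,e^{m - m^{\mathrm{new}}}\,O + e^{\tilde{m} - m^{\mathrm{new}}}\,\tilde{P}V}{\ell^{\mathrm{new}}}
\end{aligned}
$$

This appears to be the recurrence that the T8 (max), T19 (denominator), and T28 (output) updates in the kernel above implement.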

I am not sure if this is correct yet (it might be too far inlined in the main loop), but it does seem to compute something without requiring multiple top-level loops, so that is encouraging...

Note that the original algorithm calls for caching l, m, and O to gmem between outer loop iterations, which I don't currently know how best to do. I suppose we should just set their MemoryTypes and cache them before using in downstream expressions. Avoiding redundant gmem use for the output O is a challenge here but might be enabled if we can get ScanOp to properly return a reduction tensor (I've had trouble with scheduling/lowering such an approach so far but the interface is there already).
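
A sketch of that caching idea (hypothetical; assumes l, m, and O name the corresponding TensorViews and uses the existing setMemoryType/cacheAfter APIs):

// Persist the running statistics and the output accumulator to global
// memory between outer-loop iterations...
l->setMemoryType(MemoryType::Global);
m->setMemoryType(MemoryType::Global);
O->setMemoryType(MemoryType::Global);
// ...and re-read them through caches so downstream expressions consume
// the cached copies rather than gmem directly.
TensorView* l_cached = l->cacheAfter();
TensorView* m_cached = m->cacheAfter();
TensorView* O_cached = O->cacheAfter();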

naoyam added a commit that referenced this pull request Jun 28, 2025