warp specialized tma persistent kernel, step-4c, add iter grouped warp reduction #4315

Open: liqiangxl wants to merge 39 commits into main from llu/ws_tma_4c_grouped

Conversation

@liqiangxl (Collaborator) commented Apr 25, 2025

This PR follows #4288. It covers steps 4c and 4d of implementing the warp specialized TMA persistent kernel described in the design doc.
Changes:
(1) Add iteration grouped warp reduction, where the broadcast is also fused with the reduction (step-4c).
(2) Add vectorized load of cached inputs with an inner broadcast (step-4d).

Details of the iteration grouped warp reduction:

  • Input: each thread holds U elements of type T, where U is the iteration-dim unroll factor.

  • Algorithm (a simplified sketch follows this list):

    1. packedWarpReduce(): elements of type T are packed into uint64_t to minimize shuffle instructions during the warp reduction.
    2. The head thread of each warp writes its warp-reduction results to shared memory, and threads synchronize.
    3. All threads read from shared memory and reduce across warps. Since this function is always used with 4 or 8 warps, this serial all-reduce is efficient.
  • Output: all threads hold a copy of the reduction results.
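
To make the three steps concrete, here is a minimal, hypothetical CUDA sketch of the algorithm (not the actual nvfuser runtime code: it shuffles each element separately instead of packing into uint64_t, and it assumes a 32-thread warp, a static CTA size BDIMX, and room for kNumWarps partials per element in shared memory):

      // Simplified sketch of an iteration-grouped warp all-reduce with fused
      // broadcast. U is the iteration-dim unroll factor.
      template <int U, int BDIMX, typename T, typename Func>
      __device__ void iterGroupedWarpAllReduceSketch(
          T out[U], const T in[U], Func reduction_op, T* shared_mem) {
        constexpr int kNumWarps = BDIMX / 32;
        T val[U];
      #pragma unroll
        for (int i = 0; i < U; ++i) {
          val[i] = in[i];
        }
        // Step 1: butterfly reduction within the warp. The real
        // packedWarpReduce() packs the U elements of type T into uint64_t
        // words so fewer shuffles are issued; here each element is shuffled
        // separately for clarity.
        for (int offset = 16; offset > 0; offset /= 2) {
      #pragma unroll
          for (int i = 0; i < U; ++i) {
            T other = __shfl_xor_sync(0xffffffff, val[i], offset);
            reduction_op(val[i], other);
          }
        }
        // Step 2: the head thread (lane 0) of each warp publishes its
        // partial results to shared memory; then the CTA synchronizes.
        const int warp_id = threadIdx.x / 32;
        if (threadIdx.x % 32 == 0) {
      #pragma unroll
          for (int i = 0; i < U; ++i) {
            shared_mem[warp_id * U + i] = val[i];
          }
        }
        __syncthreads();
        // Step 3: every thread serially reduces the kNumWarps partials.
        // With only 4 or 8 warps this serial all-reduce is cheap, and since
        // every thread computes it, the broadcast is fused for free.
      #pragma unroll
        for (int i = 0; i < U; ++i) {
          T acc = shared_mem[i];
          for (int w = 1; w < kNumWarps; ++w) {
            reduction_op(acc, shared_mem[w * U + i]);
          }
          out[i] = acc;
        }
      }

This matches the shape of the generated call shown below, e.g. a float reduction op passed as [](float &a, float b) { a = a + b; }.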

Fusion IR changes
Using NVFUSER_DUMP=fusion_ir,cuda_to_file ./test_nvfuser --gtest_filter=TmaWarpSpecializedTest.SimpleFusion/ws_1_dtype___bfloat_batch_1056_hidden_4096 2>&1 | tee 1.log as an example, the major change is the grouped reduction, e.g.
T5_l_float[iblockIdx.y72{132}, iS71{( ceilDiv(( ceilDiv(i0, 2) ), 132) )}, iG70{2}, rthreadIdx.x68{( ceilDiv(( ceilDiv(i2, 4) ), 8) )}_p] ca_pos( 2 ) produce_pos( 2 )

CUDA code changes
(1) Grouped reduction: it uses a static CTA shape and is fused with the broadcast.

      warp::iterGroupedStaticWarpAllReduce<false, false, 2, 128>(T5.array, T20.array, [](float &a, float b) { a = a + b; }, static_cast<float*>(shared_mem), ((nvfuser_index_t)threadIdx.x));

(2) Vectorized load of the inner broadcast TV (only when it exists, e.g. RMS Norm Bwd); a generic sketch of such a vectorized copy follows the transform dump below.

      Array<float, 2, 2> T35;
      T35.set(float(0));
      if ((b26 && b62)) {
        loadGlobalToLocal<float, /*vec_size=*/2, /*is_volatile=*/false, CacheOp::Streaming>(&T35[0],  &T2[(i16 + i60)]);
      }

It is transformed as:

T35_l_float[iblockIdx.y213{132}, iS212{4}, iV211{2}, bthreadIdx.x339{1}_p, bS338{4}, bS337{8}] ca_pos( 2 )
 logical domain : (iS89{1056}, bS90{1})
 contiguity: t n
  Split: bS90{1} by factor 8 -> bS336{1}, bS337{8}
  Split: iS89{1056} by factor 2 -> iS210{528}, iV211{2}
  Split: iS210{528} by factor 132 -> iS212{4}, iblockIdx.y213{132}
  Outer split: bS336{1} by factor 4 -> bS338{4}, bthreadIdx.x339{1}_p
 loop domain : (iblockIdx.y213{132}, iS212{4}, iV211{2}, bthreadIdx.x339{1}_p, bS338{4}, bS337{8})
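
For reference, a generic sketch of what such a vectorized copy boils down to (hypothetical helper names; the runtime's actual loadGlobalToLocal additionally takes the is_volatile and CacheOp parameters seen above):

      // A VecSize-wide buffer whose alignment lets the copy below compile to
      // a single vectorized load/store instead of VecSize scalar accesses.
      template <typename T, int VecSize>
      struct alignas(sizeof(T) * VecSize) VecBuf {
        T data[VecSize];
      };

      // Both pointers must be aligned to sizeof(T) * VecSize bytes.
      template <typename T, int VecSize>
      __device__ void vectorLoadSketch(T* dst, const T* src) {
        *reinterpret_cast<VecBuf<T, VecSize>*>(dst) =
            *reinterpret_cast<const VecBuf<T, VecSize>*>(src);
      }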

@liqiangxl (Collaborator, Author):
!test

github-actions bot commented Apr 25, 2025

Review updated until commit ba0b6d6

Description

  • Added iteration grouped warp reduction with fused broadcast.

  • Vectorized load of cached input with inner broadcast.

  • Updated kernel generation to use LaunchParams instead of num_threads_per_cta.

  • Enhanced CudaKernelGenerator to handle grouped warp reduction.


Changes walkthrough 📝

Relevant files (Enhancement, 10 files):
  • codegen.cpp: Updated kernel generation to use LaunchParams and added grouped warp reduction. (+69/-20)
  • fused_reduction.cpp: Updated FusionInspector to handle grouped Id reductions. (+10/-6)
  • compiled_kernel.cpp: Updated compile method to use LaunchParams. (+3/-3)
  • executor.cpp: Updated compile method to use LaunchParams. (+5/-4)
  • normalization_inner_outer_tma_ws.cpp: Added heuristics for iteration unroll factor and vectorization. (+70/-8)
  • reduction_utils.cpp: Updated clearUnrollVectorizationAddGroupReduction to handle grouped reductions. (+13/-1)
  • test_combined_inner_outer_reduction.cpp: Updated test parameters to include contiguity and warp specialization. (+24/-12)
  • warp.cu: Added packedWarpReduce and iterGroupedStaticWarpAllReduce functions. (+88/-0)
  • codegen.h: Updated generateCudaKernel to use LaunchParams. (+1/-1)
  • compiled_kernel.h: Updated compile method to use LaunchParams. (+1/-1)

PR Reviewer Guide 🔍

Here are some key observations to aid the review process:

🧪 PR contains tests
⚡ Recommended focus areas for review

Performance Concern

The new genGroupedWarpReduction function and related changes should be thoroughly tested for performance improvements. Ensure that the new grouped warp reduction is faster than the previous method.

void genGroupedWarpReduction(
    const int num_grouped_iterations,
    kir::TensorIndex* output,
    kir::TensorIndex* input,
    const Val* init,
    BinaryOpType reduction_op_type,
    kir::Predicate* read_pred) {
  ArgumentBuilder func_args;
  func_args.arg(genVariableNameConvertAlignedArray(output));
  func_args.arg(genVariableNameConvertAlignedArray(input));
  func_args.arg(genReductionOp(reduction_op_type, output->dtype()));
  func_args.arg(genStaticCast(genPtrType(output->dtype()), "shared_mem"));

  ArgumentBuilder template_args;
  func_args.arg(genInline(NamedScalar::getParallelIndex(ParallelType::TIDx)));
  template_args.arg(kernel_->getWarpPaddedParallelInfo().is_tidx_single_warp);
  template_args.arg(isAligned());
  template_args.arg(num_grouped_iterations);
  template_args.arg(lparams_.bdimx());
  indent() << genCall(
                  "warp::iterGroupedStaticWarpAllReduce",
                  template_args,
                  func_args)
           << ";\n";
}
void handle(const GroupedReductionOp* grouped_rop) final {
  const auto num_grouped_iterations =
      getGroupedLoopIndexConcreteIntSets().size();

  const auto num_grouped_exprs = grouped_rop->numHorizontallyGroupedExprs();

  // special version where only iteration is grouped.
  // used for outer reduction with vectorized iteration domain.
  if (num_grouped_iterations > 1 && num_grouped_exprs == 1) {
    const auto output = grouped_rop->output(0)->as<kir::TensorIndex>();
    const auto input = grouped_rop->input(0)->as<kir::TensorIndex>();
    const auto op_type = grouped_rop->getReductionOpType(0);
    const auto domain = output->view()->domain();
    const bool has_block_reduce = domain->hasBlockReduction();
    const bool has_grid_reduce = domain->hasGridReduction();
    NVF_ERROR(
        !has_grid_reduce, "IterGroupedGridReduction not implemented yet");
    NVF_ERROR(
        has_block_reduce,
        "To use IterGroupedBlockReduction, must have block reduce!");
    if (auto reduction_ids =
            ir_utils::getMaybeWarpReductionDim(output, input)) {
      NVF_ERROR(
          lparams_.bdimx() % 128 == 0,
          "iterGroupedStaticWarpAllReduce() requires bdimx % 128 == 0.");
      NVF_ERROR(
          grouped_rop->isAllreduce(),
          "iterGroupedStaticWarpAllReduce should be used for allreduce.");
      NVF_ERROR(
          reduction_ids.value().first &&
              reduction_ids.value().first->getParallelType() ==
                  ParallelType::TIDx &&
              reduction_ids.value().second == nullptr,
          "Grouped warp reduction is only supported for TIDx reduction with no second dimension.");
      return genGroupedWarpReduction(
          (int)num_grouped_iterations,
          output,
          input,
          grouped_rop->initVal(0),
          op_type,
          grouped_rop->predicate());
    } else {
      return genIterGroupedBlockReduction(
          (int)num_grouped_iterations,
          output,
          input,
          grouped_rop->initVal(0),
          op_type,
          grouped_rop->predicate(),
          grouped_rop->writePredicate());
    }
  }
Code Complexity

The logic for determining the iteration unroll factor and vectorization factor has become more complex. Ensure that this logic is correct and that it does not introduce any regressions.

rparams->circular_buffer_options = circular_buffer_options;

// TODO: This is a heuristic, need to be tuned.
// Iteration unroll factor, limited by:
// (1) heuristic selection
// (2) max possible due to smem limitation
// (3) Predicate of 1D TMA load requires iteration dim divisible by unroll
//     factor.
// (4) Round down to power of 2, since we need vectorized access in
//     smem reduction and loading of inner broadcast tv.
iter_remaining = scheduler_utils::safeDiv(iter_remaining, n_stages);
int64_t heu_iter_unroll = std::min(2L, iter_remaining);
int64_t max_iter_unroll = max_n_copies / n_stages;
int64_t iter_unroll_factor = std::min(heu_iter_unroll, max_iter_unroll);
iter_unroll_factor = scheduler_utils::lastPow2(iter_unroll_factor);
while (outer_dim_numel % iter_unroll_factor) {
  iter_unroll_factor /= 2;
}
rparams->unroll_factor_iter_dom = iter_unroll_factor;
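
To sanity-check the rounding logic, here is a standalone sketch with hypothetical values (assuming, from their names, that safeDiv divides and clamps to at least 1 and lastPow2 rounds down to a power of two):

      #include <algorithm>
      #include <cstdint>

      int64_t lastPow2(int64_t n) {
        int64_t p = 1;
        while (p * 2 <= n) {
          p *= 2;
        }
        return p;
      }

      int64_t pickIterUnrollFactor(
          int64_t iter_remaining,   // iteration elements left per CTA
          int64_t n_stages,         // circular-buffer stages
          int64_t max_n_copies,     // max buffer copies allowed by smem
          int64_t outer_dim_numel) {
        iter_remaining = std::max<int64_t>(1, iter_remaining / n_stages);
        int64_t heu_iter_unroll = std::min<int64_t>(2, iter_remaining);  // (1)
        int64_t max_iter_unroll = max_n_copies / n_stages;               // (2)
        int64_t factor =
            std::max<int64_t>(std::min(heu_iter_unroll, max_iter_unroll), 1);
        factor = lastPow2(factor);                                       // (4)
        while (outer_dim_numel % factor) {                               // (3)
          factor /= 2;
        }
        return factor;
      }

      // e.g. iter_remaining = 8, n_stages = 2, max_n_copies = 8, and
      // outer_dim_numel = 1056: the heuristic gives 2, the smem limit gives
      // 4, lastPow2(2) = 2, and 1056 % 2 == 0, so the unroll factor is 2.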
Test Coverage

The new test cases for TmaWarpSpecializedTest should cover a wide range of scenarios, including edge cases. Ensure that the tests are comprehensive and that they validate the correctness of the new grouped warp reduction.

      scheduleAndRun(&fusion, SchedulerType::InnerOuterPersistent, {t0});
  auto persistent_params = cg_results.heuristic_params->as<ReductionParams>();
  ASSERT_FALSE(persistent_params->project_persistent_buffers);
  testValidate(&fusion_copy, cg_results.outputs, {t0}, __LINE__, __FILE__);
}

// contig, enable WarpSpecializedNormalization, dtype, dim0, dim1
using TmaWarpSpecializedParams =
    std::tuple<bool, bool, DataType, int64_t, int64_t>;
class TmaWarpSpecializedTest
    : public NVFuserFixtureParamTest<TmaWarpSpecializedParams> {
 public:
  void SetUp() override {
    opt_guard_ = std::make_unique<EnableOptionsGuard>();
    if (std::get<1>(GetParam())) {
      EnableOptionsGuard::getCurOptions().set(
          EnableOption::WarpSpecializedNormalization);
    } else {
      EnableOptionsGuard::getCurOptions().unset(
          EnableOption::WarpSpecializedNormalization);
    }
    NVFuserTest::SetUp();
  }

 protected:
  // This keeps the guard alive until all TmaWarpSpecializedTests are done.
  std::unique_ptr<EnableOptionsGuard> opt_guard_;
};

TEST_P(TmaWarpSpecializedTest, SimpleFusion) {
  NVFUSER_TEST_CUDA_ARCH_GUARD(9, 0);
  auto [contig, _, dtype, dim0, dim1] = GetParam();
  if (!contig) {
    GTEST_SKIP() << "TMA load requires contig inner domain.";
  }
  auto fusion = std::make_unique<Fusion>();
  FusionGuard fg(fusion.get());
  auto tv0 = makeContigTensor(2, dtype);
  auto tv1 = makeContigTensor(2, dtype);
  fusion->addInput(tv0);
  fusion->addInput(tv1);
  tv0 = maybeCastOp(DataType::Float, tv0);
  tv1 = maybeCastOp(DataType::Float, tv1);
  auto tv2 = add(tv0, tv1);
  auto tv3 = sum(tv2, {1});
  auto tv4 = broadcast(tv3, {false, true});
  auto tv5 = add(tv2, tv4);
  auto tv6 = sum(tv1, {0});
  tv5 = maybeCastOp(dtype, tv5);
  fusion->addOutput(tv5);
  fusion->addOutput(tv6);
  auto fusion_copy = *fusion;

  auto options =
      at::TensorOptions().dtype(data_type_to_aten(dtype)).device(at::kCUDA, 0);
  at::Tensor t0 = at::randn({dim0, dim1}, options);
  at::Tensor t1 = at::randn({dim0, dim1}, options);

  FusionExecutorCache executor_cache(std::move(fusion));
  auto cg_outputs = executor_cache.runFusionWithInputs({t0, t1});
  auto runtime = executor_cache.getMostRecentKernelRuntime();
  EXPECT_THAT(
      runtime->fusionSegments()->groups(),
      UnorderedElementsAre(HeuristicIs(SchedulerType::InnerOuterPersistent)));
  testValidate(&fusion_copy, cg_outputs, {t0, t1}, __LINE__, __FILE__);
}

TEST_P(TmaWarpSpecializedTest, RMSNormBwd) {
  NVFUSER_TEST_CUDA_ARCH_GUARD(9, 0);
  auto [contig, _, dtype, dim0, dim1] = GetParam();
  std::vector<int64_t> norm_shape{dim1};

  auto fusion = std::make_unique<Fusion>();
  FusionGuard fg(fusion.get());
  auto grad_out = makeContigTensor(2, dtype);
  auto input = makeContigTensor(2, dtype);
  auto rstd = contig ? makeContigConcreteTensor({dim0, 1})
                     : makeConcreteTensor({dim0, 1});
  auto weight = makeContigTensor(1, dtype);
  fusion->addInput(grad_out);
  fusion->addInput(input);
  fusion->addInput(rstd);
  fusion->addInput(weight);

  grad_out = maybeCastOp(DataType::Float, grad_out);
  input = maybeCastOp(DataType::Float, input);
  weight = maybeCastOp(DataType::Float, weight);
  auto grads = rms_norm_backward(
      grad_out, input, norm_shape, rstd, weight, {true, true});
  grads.grad_input = maybeCastOp(dtype, grads.grad_input);
  grads.grad_weight = maybeCastOp(dtype, grads.grad_weight);
  fusion->addOutput(grads.grad_input);
  fusion->addOutput(grads.grad_weight);

  auto fusion_copy = *fusion;
  auto options =
      at::TensorOptions().dtype(data_type_to_aten(dtype)).device(at::kCUDA, 0);
  std::vector<int64_t> shape{dim0, dim1};
  at::Tensor aten_grad_out = at::randn(shape, options);
  at::Tensor aten_input = at::randn(shape, options);
  at::Tensor aten_weight = at::randn(norm_shape, options);
  const float kEps = 1e-6;
  auto pow2 = at::pow(aten_input.to(at::kFloat), 2);
  auto sum = at::sum(pow2, -1, true);
  auto var = at::mul(sum, 1.0 / dim1);
  auto aten_rstd = at::pow(at::add(var, kEps), -0.5);

  FusionExecutorCache executor_cache(std::move(fusion));
  KernelArgumentHolder args = {
      aten_grad_out, aten_input, aten_rstd, aten_weight};
  auto cg_outputs = executor_cache.runFusionWithInputs(args);
  auto runtime = executor_cache.getMostRecentKernelRuntime();
  EXPECT_THAT(
      runtime->fusionSegments()->groups(),
      UnorderedElementsAre(HeuristicIs(SchedulerType::InnerOuterPersistent)));
  testValidate(
      &fusion_copy,
      cg_outputs,
      {aten_grad_out, aten_input, aten_rstd, aten_weight},
      __LINE__,
      __FILE__);
}
auto TmaWarpSpecializedTestParams() {
  std::vector<TmaWarpSpecializedParams> values;
  // Use 8 * SMs as the outer dimension to ensure divisible split by unroll
  // factor (1 or 2) and SM count.
  int64_t dim0 =
      8 * at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
  for (int64_t dim1 = 1024; dim1 <= 8192; dim1 += 1024) {
    for (auto dtype : {DataType::Float, DataType::BFloat16}) {
      for (bool warp_specialized : {true, false}) {
        for (bool contig : {true, false}) {
          if (!warp_specialized && !contig) {
            // Don't need to test non-contiguous version when warp
            // specialization is not used.
            continue;
          }
          values.emplace_back(contig, warp_specialized, dtype, dim0, dim1);
        }
      }
    }
  }
  return testing::ValuesIn(values);
}
INSTANTIATE_TEST_SUITE_P(
    ,
    TmaWarpSpecializedTest,
    TmaWarpSpecializedTestParams(),
    [](const testing::TestParamInfo<TmaWarpSpecializedParams>& info)
        -> std::string {
      std::stringstream ss;
      ss << "contig_" << std::get<0>(info.param);
      ss << "_ws_" << std::get<1>(info.param);
      ss << "_dtype_" << std::get<2>(info.param);
      ss << "_batch_" << std::get<3>(info.param);
      ss << "_hidden_" << std::get<4>(info.param);
      return sanitizeTestName(ss.str());
    });
} // namespace nvfuser

@liqiangxl (Collaborator, Author):

!test

@liqiangxl (Collaborator, Author):

!test --dev

liqiangxl force-pushed the llu/ws_tma_4c_grouped branch from 5a9216c to bcc6d91 on April 29, 2025 12:48
@liqiangxl (Collaborator, Author):

!test

@liqiangxl (Collaborator, Author):

!test --dev

@liqiangxl (Collaborator, Author):

Fusion IR

Inputs:
  T0_g___bfloat[iS138{132}, iS137{( ceilDiv(( ceilDiv(i0, 2) ), 132) )}, iS136{2}, iS1{i2}]
  T1_g___bfloat[iS114{132}, iS113{( ceilDiv(( ceilDiv(i0, 2) ), 132) )}, iS112{2}, iS21{i2}]
Outputs:
  T9_g___bfloat[iblockIdx.y154{132}, iS153{( ceilDiv(( ceilDiv(i0, 2) ), 132) )}, iUR152{2}, ithreadIdx.x208{( ceilDiv(( ceilDiv(i2, 4) ), 8) )}_p, iS207{8}, iV206{4}] ca_pos( 2 ) produce_pos( 2 )
  T8_g_float[iblockIdx.y241{( ceilDiv(( ceilDiv(i2, 4) ), 128) )}, ithreadIdx.x242{128}, iV240{4}] ca_pos( 2 )

%kernel {
T14_s___bfloat[iblockIdx.y134{132}, iS133{( ceilDiv(( ceilDiv(i0, 2) ), 132) )}, iUR132{2}, iB39{i2}] ca_pos( 2 )
   = CpAsyncBulk( T0_g___bfloat[iS138{132}, iS137{( ceilDiv(( ceilDiv(i0, 2) ), 132) )}, iS136{2}, iS1{i2}] )
T18_l___bfloat[iblockIdx.y130{132}, iS129{( ceilDiv(( ceilDiv(i0, 2) ), 132) )}, iS128{2}, ithreadIdx.x192{( ceilDiv(( ceilDiv(i2, 4) ), 8) )}_p, iS191{8}, iV190{4}] ca_pos( 2 ) produce_pos( 2 )
   = Set( T14_s___bfloat[iblockIdx.y134{132}, iS133{( ceilDiv(( ceilDiv(i0, 2) ), 132) )}, iUR132{2}, iB39{i2}] ca_pos( 2 ), cache_op=Streaming )
T10_l_float[iblockIdx.y126{132}, iS125{( ceilDiv(( ceilDiv(i0, 2) ), 132) )}, iS124{2}, ithreadIdx.x188{( ceilDiv(( ceilDiv(i2, 4) ), 8) )}_p, iS187{8}, iS186{4}] ca_pos( 6 ) produce_pos( 2 )
   = __bfloat2float(T18_l___bfloat[iblockIdx.y130{132}, iS129{( ceilDiv(( ceilDiv(i0, 2) ), 132) )}, iS128{2}, ithreadIdx.x192{( ceilDiv(( ceilDiv(i2, 4) ), 8) )}_p, iS191{8}, iV190{4}] ca_pos( 2 ) produce_pos( 2 ));
T15_s___bfloat[iblockIdx.y110{132}, iS109{( ceilDiv(( ceilDiv(i0, 2) ), 132) )}, iUR108{2}, iB41{i2}] ca_pos( 2 )
   = CpAsyncBulk( T1_g___bfloat[iS114{132}, iS113{( ceilDiv(( ceilDiv(i0, 2) ), 132) )}, iS112{2}, iS21{i2}] )
T19_l___bfloat[iblockIdx.y106{132}, iS105{( ceilDiv(( ceilDiv(i0, 2) ), 132) )}, iS104{2}, ithreadIdx.x176{( ceilDiv(( ceilDiv(i2, 4) ), 8) )}_p, iS175{8}, iV174{4}] ca_pos( 2 ) produce_pos( 2 )
   = Set( T15_s___bfloat[iblockIdx.y110{132}, iS109{( ceilDiv(( ceilDiv(i0, 2) ), 132) )}, iUR108{2}, iB41{i2}] ca_pos( 2 ), cache_op=Streaming )
T11_l_float[iblockIdx.y118{132}, iS117{( ceilDiv(( ceilDiv(i0, 2) ), 132) )}, iS116{2}, ithreadIdx.x180{( ceilDiv(( ceilDiv(i2, 4) ), 8) )}_p, iS179{8}, iS178{4}] ca_pos( 6 ) produce_pos( 2 )
   = __bfloat2float(T19_l___bfloat[iblockIdx.y106{132}, iS105{( ceilDiv(( ceilDiv(i0, 2) ), 132) )}, iS104{2}, ithreadIdx.x176{( ceilDiv(( ceilDiv(i2, 4) ), 8) )}_p, iS175{8}, iV174{4}] ca_pos( 2 ) produce_pos( 2 ));
T12_l_float[iblockIdx.y122{132}, iS121{( ceilDiv(( ceilDiv(i0, 2) ), 132) )}, iS120{2}, ithreadIdx.x184{( ceilDiv(( ceilDiv(i2, 4) ), 8) )}_p, iS183{8}, iS182{4}] ca_pos( 6 ) produce_pos( 6 )
   = T10_l_float[iblockIdx.y126{132}, iS125{( ceilDiv(( ceilDiv(i0, 2) ), 132) )}, iS124{2}, ithreadIdx.x188{( ceilDiv(( ceilDiv(i2, 4) ), 8) )}_p, iS187{8}, iS186{4}] ca_pos( 6 ) produce_pos( 2 )
   + T11_l_float[iblockIdx.y118{132}, iS117{( ceilDiv(( ceilDiv(i0, 2) ), 132) )}, iS116{2}, ithreadIdx.x180{( ceilDiv(( ceilDiv(i2, 4) ), 8) )}_p, iS179{8}, iS178{4}] ca_pos( 6 ) produce_pos( 2 );
T2_l_float[iblockIdx.y158{132}, iS157{( ceilDiv(( ceilDiv(i0, 2) ), 132) )}, iS156{2}, ithreadIdx.x216{( ceilDiv(( ceilDiv(i2, 4) ), 8) )}_p, iS215{8}, iS214{4}] ca_pos( 6 ) produce_pos( 2 )
   = __bfloat2float(T18_l___bfloat[iblockIdx.y130{132}, iS129{( ceilDiv(( ceilDiv(i0, 2) ), 132) )}, iS128{2}, ithreadIdx.x192{( ceilDiv(( ceilDiv(i2, 4) ), 8) )}_p, iS191{8}, iV190{4}] ca_pos( 2 ) produce_pos( 2 ));
T3_l_float[iblockIdx.y102{132}, iS101{( ceilDiv(( ceilDiv(i0, 2) ), 132) )}, iS100{2}, ithreadIdx.x172{( ceilDiv(( ceilDiv(i2, 4) ), 8) )}_p, iS171{8}, iS170{4}] ca_pos( 6 ) produce_pos( 2 )
   = __bfloat2float(T19_l___bfloat[iblockIdx.y106{132}, iS105{( ceilDiv(( ceilDiv(i0, 2) ), 132) )}, iS104{2}, ithreadIdx.x176{( ceilDiv(( ceilDiv(i2, 4) ), 8) )}_p, iS175{8}, iV174{4}] ca_pos( 2 ) produce_pos( 2 ));
T4_l_float[iblockIdx.y98{132}, iS97{( ceilDiv(( ceilDiv(i0, 2) ), 132) )}, iS96{2}, ithreadIdx.x168{( ceilDiv(( ceilDiv(i2, 4) ), 8) )}_p, iS167{8}, iS166{4}] ca_pos( 6 ) produce_pos( 6 )
   = T2_l_float[iblockIdx.y158{132}, iS157{( ceilDiv(( ceilDiv(i0, 2) ), 132) )}, iS156{2}, ithreadIdx.x216{( ceilDiv(( ceilDiv(i2, 4) ), 8) )}_p, iS215{8}, iS214{4}] ca_pos( 6 ) produce_pos( 2 )
   + T3_l_float[iblockIdx.y102{132}, iS101{( ceilDiv(( ceilDiv(i0, 2) ), 132) )}, iS100{2}, ithreadIdx.x172{( ceilDiv(( ceilDiv(i2, 4) ), 8) )}_p, iS171{8}, iS170{4}] ca_pos( 6 ) produce_pos( 2 );
T20_l_float[iblockIdx.y62{132}, iS61{( ceilDiv(( ceilDiv(i0, 2) ), 132) )}, iS60{2}, ithreadIdx.x66{( ceilDiv(( ceilDiv(i2, 4) ), 8) )}rf_p, rS65{8}rf, rS64{4}rf] ca_pos( 2 ) produce_pos( 6 )
   = reduction( T4_l_float[iblockIdx.y98{132}, iS97{( ceilDiv(( ceilDiv(i0, 2) ), 132) )}, iS96{2}, ithreadIdx.x168{( ceilDiv(( ceilDiv(i2, 4) ), 8) )}_p, iS167{8}, iS166{4}] ca_pos( 6 ) produce_pos( 6 ), op = add, initial value = float(0), allreduce = false )
T5_l_float[iblockIdx.y72{132}, iS71{( ceilDiv(( ceilDiv(i0, 2) ), 132) )}, iG70{2}, rthreadIdx.x68{( ceilDiv(( ceilDiv(i2, 4) ), 8) )}_p] ca_pos( 2 ) produce_pos( 2 )
   = reduction( T20_l_float[iblockIdx.y62{132}, iS61{( ceilDiv(( ceilDiv(i0, 2) ), 132) )}, iS60{2}, ithreadIdx.x66{( ceilDiv(( ceilDiv(i2, 4) ), 8) )}rf_p, rS65{8}rf, rS64{4}rf] ca_pos( 2 ) produce_pos( 6 ), op = add, initial value = float(0), allreduce = false )
T6_l_float[iblockIdx.y146{132}, iS145{( ceilDiv(( ceilDiv(i0, 2) ), 132) )}, iS144{2}, bthreadIdx.x200{1}_p, bS199{8}, bS198{4}] ca_pos( 3 ) produce_pos( 2 )
   = broadcast( T5_l_float[iblockIdx.y72{132}, iS71{( ceilDiv(( ceilDiv(i0, 2) ), 132) )}, iG70{2}, rthreadIdx.x68{( ceilDiv(( ceilDiv(i2, 4) ), 8) )}_p] ca_pos( 2 ) produce_pos( 2 ), flags = {false, true} )
T7_l_float[iblockIdx.y142{132}, iS141{( ceilDiv(( ceilDiv(i0, 2) ), 132) )}, iS140{2}, ithreadIdx.x196{( ceilDiv(( ceilDiv(i2, 4) ), 8) )}_p, iS195{8}, iS194{4}] ca_pos( 6 ) produce_pos( 6 )
   = T12_l_float[iblockIdx.y122{132}, iS121{( ceilDiv(( ceilDiv(i0, 2) ), 132) )}, iS120{2}, ithreadIdx.x184{( ceilDiv(( ceilDiv(i2, 4) ), 8) )}_p, iS183{8}, iS182{4}] ca_pos( 6 ) produce_pos( 6 )
   + T6_l_float[iblockIdx.y146{132}, iS145{( ceilDiv(( ceilDiv(i0, 2) ), 132) )}, iS144{2}, bthreadIdx.x200{1}_p, bS199{8}, bS198{4}] ca_pos( 3 ) produce_pos( 2 );
T16_l___bfloat[iblockIdx.y150{132}, iS149{( ceilDiv(( ceilDiv(i0, 2) ), 132) )}, iS148{2}, ithreadIdx.x204{( ceilDiv(( ceilDiv(i2, 4) ), 8) )}_p, iS203{8}, iS202{4}] ca_pos( 2 ) produce_pos( 6 )
   = __float2bfloat(T7_l_float[iblockIdx.y142{132}, iS141{( ceilDiv(( ceilDiv(i0, 2) ), 132) )}, iS140{2}, ithreadIdx.x196{( ceilDiv(( ceilDiv(i2, 4) ), 8) )}_p, iS195{8}, iS194{4}] ca_pos( 6 ) produce_pos( 6 ));
T9_g___bfloat[iblockIdx.y154{132}, iS153{( ceilDiv(( ceilDiv(i0, 2) ), 132) )}, iUR152{2}, ithreadIdx.x208{( ceilDiv(( ceilDiv(i2, 4) ), 8) )}_p, iS207{8}, iV206{4}] ca_pos( 2 ) produce_pos( 2 )
   = Set( T16_l___bfloat[iblockIdx.y150{132}, iS149{( ceilDiv(( ceilDiv(i0, 2) ), 132) )}, iS148{2}, ithreadIdx.x204{( ceilDiv(( ceilDiv(i2, 4) ), 8) )}_p, iS203{8}, iS202{4}] ca_pos( 2 ) produce_pos( 6 ), cache_op=Streaming )
T22_l_float[iblockIdx.y82{132}rf, rS81{( ceilDiv(( ceilDiv(i0, 2) ), 132) )}rf, rS80{2}rf, ithreadIdx.x212{( ceilDiv(( ceilDiv(i2, 4) ), 8) )}_p, iS211{8}, iS210{4}] ca_pos( 1 ) produce_pos( 6 )
   = reduction( T3_l_float[iblockIdx.y102{132}, iS101{( ceilDiv(( ceilDiv(i0, 2) ), 132) )}, iS100{2}, ithreadIdx.x172{( ceilDiv(( ceilDiv(i2, 4) ), 8) )}_p, iS171{8}, iS170{4}] ca_pos( 6 ) produce_pos( 2 ), op = add, initial value = float(0), allreduce = false )
T21_g_float[iblockIdx.y85{132}, ithreadIdx.x232{( ceilDiv(( ceilDiv(i2, 4) ), 8) )}_p, iS231{8}, iV230{4}] produce_pos( 1 )
   = Set( T22_l_float[iblockIdx.y82{132}rf, rS81{( ceilDiv(( ceilDiv(i0, 2) ), 132) )}rf, rS80{2}rf, ithreadIdx.x212{( ceilDiv(( ceilDiv(i2, 4) ), 8) )}_p, iS211{8}, iS210{4}] ca_pos( 1 ) produce_pos( 6 ), cache_op=Streaming )
T23_l_float[iS233{132}, iS234{1}, iblockIdx.y237{( ceilDiv(( ceilDiv(i2, 4) ), 128) )}, ithreadIdx.x238{128}, iV236{4}] ca_pos( 4 )
   = Set( T21_g_float[iblockIdx.y85{132}, ithreadIdx.x232{( ceilDiv(( ceilDiv(i2, 4) ), 8) )}_p, iS231{8}, iV230{4}] produce_pos( 1 ), cache_op=Streaming )
T17_l_float[rS89{132}, rS90{1}, iblockIdx.y93{( ceilDiv(( ceilDiv(i2, 4) ), 128) )}, ithreadIdx.x94{128}, iS92{4}] produce_pos( 4 )
   = reduction( T23_l_float[iS233{132}, iS234{1}, iblockIdx.y237{( ceilDiv(( ceilDiv(i2, 4) ), 128) )}, ithreadIdx.x238{128}, iV236{4}] ca_pos( 4 ), op = add, initial value = float(0), allreduce = false )
T8_g_float[iblockIdx.y241{( ceilDiv(( ceilDiv(i2, 4) ), 128) )}, ithreadIdx.x242{128}, iV240{4}] ca_pos( 2 )
   = Set( T17_l_float[rS89{132}, rS90{1}, iblockIdx.y93{( ceilDiv(( ceilDiv(i2, 4) ), 128) )}, ithreadIdx.x94{128}, iS92{4}] produce_pos( 4 ), cache_op=Streaming )

TransformPrinter : 
T0_g___bfloat[iS138{132}, iS137{( ceilDiv(( ceilDiv(i0, 2) ), 132) )}, iS136{2}, iS1{i2}]
 logical domain : (iS0{i0}, iS1{i2})
 contiguity: t t
  Split: iS0{i0} by factor 2 -> iS135{( ceilDiv(i0, 2) )}, iS136{2}
  Split: iS135{( ceilDiv(i0, 2) )} by factor 132 -> iS137{( ceilDiv(( ceilDiv(i0, 2) ), 132) )}, iS138{132}
 loop domain : (iS138{132}, iS137{( ceilDiv(( ceilDiv(i0, 2) ), 132) )}, iS136{2}, iS1{i2})
T14_s___bfloat[iblockIdx.y134{132}, iS133{( ceilDiv(( ceilDiv(i0, 2) ), 132) )}, iUR132{2}, iB39{i2}] ca_pos( 2 )
 logical domain : (iS38{i0}, iB39{i2})
 contiguity: t t
  Split: iS38{i0} by factor 2 -> iS131{( ceilDiv(i0, 2) )}, iUR132{2}
  Split: iS131{( ceilDiv(i0, 2) )} by factor 132 -> iS133{( ceilDiv(( ceilDiv(i0, 2) ), 132) )}, iblockIdx.y134{132}
 loop domain : (iblockIdx.y134{132}, iS133{( ceilDiv(( ceilDiv(i0, 2) ), 132) )}, iUR132{2}, iB39{i2})
T18_l___bfloat[iblockIdx.y130{132}, iS129{( ceilDiv(( ceilDiv(i0, 2) ), 132) )}, iS128{2}, ithreadIdx.x192{( ceilDiv(( ceilDiv(i2, 4) ), 8) )}_p, iS191{8}, iV190{4}] ca_pos( 2 ) produce_pos( 2 )
 logical domain : (iS45{i0}, iS46{i2})
 contiguity: t t
  Split: iS46{i2} by factor 4 -> iS189{( ceilDiv(i2, 4) )}, iV190{4}
  Split: iS45{i0} by factor 2 -> iS127{( ceilDiv(i0, 2) )}, iS128{2}
  Split: iS127{( ceilDiv(i0, 2) )} by factor 132 -> iS129{( ceilDiv(( ceilDiv(i0, 2) ), 132) )}, iblockIdx.y130{132}
  Outer split: iS189{( ceilDiv(i2, 4) )} by factor 8 -> iS191{8}, ithreadIdx.x192{( ceilDiv(( ceilDiv(i2, 4) ), 8) )}_p
 loop domain : (iblockIdx.y130{132}, iS129{( ceilDiv(( ceilDiv(i0, 2) ), 132) )}, iS128{2}, ithreadIdx.x192{( ceilDiv(( ceilDiv(i2, 4) ), 8) )}_p, iS191{8}, iV190{4})
T10_l_float[iblockIdx.y126{132}, iS125{( ceilDiv(( ceilDiv(i0, 2) ), 132) )}, iS124{2}, ithreadIdx.x188{( ceilDiv(( ceilDiv(i2, 4) ), 8) )}_p, iS187{8}, iS186{4}] ca_pos( 6 ) produce_pos( 2 )
 logical domain : (iS30{i0}, iS31{i2})
 contiguity: t t
  Split: iS31{i2} by factor 4 -> iS185{( ceilDiv(i2, 4) )}, iS186{4}
  Split: iS30{i0} by factor 2 -> iS123{( ceilDiv(i0, 2) )}, iS124{2}
  Split: iS123{( ceilDiv(i0, 2) )} by factor 132 -> iS125{( ceilDiv(( ceilDiv(i0, 2) ), 132) )}, iblockIdx.y126{132}
  Outer split: iS185{( ceilDiv(i2, 4) )} by factor 8 -> iS187{8}, ithreadIdx.x188{( ceilDiv(( ceilDiv(i2, 4) ), 8) )}_p
 loop domain : (iblockIdx.y126{132}, iS125{( ceilDiv(( ceilDiv(i0, 2) ), 132) )}, iS124{2}, ithreadIdx.x188{( ceilDiv(( ceilDiv(i2, 4) ), 8) )}_p, iS187{8}, iS186{4})
T1_g___bfloat[iS114{132}, iS113{( ceilDiv(( ceilDiv(i0, 2) ), 132) )}, iS112{2}, iS21{i2}]
 logical domain : (iS20{i0}, iS21{i2})
 contiguity: t t
  Split: iS20{i0} by factor 2 -> iS111{( ceilDiv(i0, 2) )}, iS112{2}
  Split: iS111{( ceilDiv(i0, 2) )} by factor 132 -> iS113{( ceilDiv(( ceilDiv(i0, 2) ), 132) )}, iS114{132}
 loop domain : (iS114{132}, iS113{( ceilDiv(( ceilDiv(i0, 2) ), 132) )}, iS112{2}, iS21{i2})
T15_s___bfloat[iblockIdx.y110{132}, iS109{( ceilDiv(( ceilDiv(i0, 2) ), 132) )}, iUR108{2}, iB41{i2}] ca_pos( 2 )
 logical domain : (iS40{i0}, iB41{i2})
 contiguity: t t
  Split: iS40{i0} by factor 2 -> iS107{( ceilDiv(i0, 2) )}, iUR108{2}
  Split: iS107{( ceilDiv(i0, 2) )} by factor 132 -> iS109{( ceilDiv(( ceilDiv(i0, 2) ), 132) )}, iblockIdx.y110{132}
 loop domain : (iblockIdx.y110{132}, iS109{( ceilDiv(( ceilDiv(i0, 2) ), 132) )}, iUR108{2}, iB41{i2})
T19_l___bfloat[iblockIdx.y106{132}, iS105{( ceilDiv(( ceilDiv(i0, 2) ), 132) )}, iS104{2}, ithreadIdx.x176{( ceilDiv(( ceilDiv(i2, 4) ), 8) )}_p, iS175{8}, iV174{4}] ca_pos( 2 ) produce_pos( 2 )
 logical domain : (iS47{i0}, iS48{i2})
 contiguity: t t
  Split: iS48{i2} by factor 4 -> iS173{( ceilDiv(i2, 4) )}, iV174{4}
  Split: iS47{i0} by factor 2 -> iS103{( ceilDiv(i0, 2) )}, iS104{2}
  Split: iS103{( ceilDiv(i0, 2) )} by factor 132 -> iS105{( ceilDiv(( ceilDiv(i0, 2) ), 132) )}, iblockIdx.y106{132}
  Outer split: iS173{( ceilDiv(i2, 4) )} by factor 8 -> iS175{8}, ithreadIdx.x176{( ceilDiv(( ceilDiv(i2, 4) ), 8) )}_p
 loop domain : (iblockIdx.y106{132}, iS105{( ceilDiv(( ceilDiv(i0, 2) ), 132) )}, iS104{2}, ithreadIdx.x176{( ceilDiv(( ceilDiv(i2, 4) ), 8) )}_p, iS175{8}, iV174{4})
T11_l_float[iblockIdx.y118{132}, iS117{( ceilDiv(( ceilDiv(i0, 2) ), 132) )}, iS116{2}, ithreadIdx.x180{( ceilDiv(( ceilDiv(i2, 4) ), 8) )}_p, iS179{8}, iS178{4}] ca_pos( 6 ) produce_pos( 2 )
 logical domain : (iS32{i0}, iS33{i2})
 contiguity: t t
  Split: iS33{i2} by factor 4 -> iS177{( ceilDiv(i2, 4) )}, iS178{4}
  Split: iS32{i0} by factor 2 -> iS115{( ceilDiv(i0, 2) )}, iS116{2}
  Split: iS115{( ceilDiv(i0, 2) )} by factor 132 -> iS117{( ceilDiv(( ceilDiv(i0, 2) ), 132) )}, iblockIdx.y118{132}
  Outer split: iS177{( ceilDiv(i2, 4) )} by factor 8 -> iS179{8}, ithreadIdx.x180{( ceilDiv(( ceilDiv(i2, 4) ), 8) )}_p
 loop domain : (iblockIdx.y118{132}, iS117{( ceilDiv(( ceilDiv(i0, 2) ), 132) )}, iS116{2}, ithreadIdx.x180{( ceilDiv(( ceilDiv(i2, 4) ), 8) )}_p, iS179{8}, iS178{4})
T12_l_float[iblockIdx.y122{132}, iS121{( ceilDiv(( ceilDiv(i0, 2) ), 132) )}, iS120{2}, ithreadIdx.x184{( ceilDiv(( ceilDiv(i2, 4) ), 8) )}_p, iS183{8}, iS182{4}] ca_pos( 6 ) produce_pos( 6 )
 logical domain : (iS34{i0}, iS35{i2})
 contiguity: t t
  Split: iS35{i2} by factor 4 -> iS181{( ceilDiv(i2, 4) )}, iS182{4}
  Split: iS34{i0} by factor 2 -> iS119{( ceilDiv(i0, 2) )}, iS120{2}
  Split: iS119{( ceilDiv(i0, 2) )} by factor 132 -> iS121{( ceilDiv(( ceilDiv(i0, 2) ), 132) )}, iblockIdx.y122{132}
  Outer split: iS181{( ceilDiv(i2, 4) )} by factor 8 -> iS183{8}, ithreadIdx.x184{( ceilDiv(( ceilDiv(i2, 4) ), 8) )}_p
 loop domain : (iblockIdx.y122{132}, iS121{( ceilDiv(( ceilDiv(i0, 2) ), 132) )}, iS120{2}, ithreadIdx.x184{( ceilDiv(( ceilDiv(i2, 4) ), 8) )}_p, iS183{8}, iS182{4})
T2_l_float[iblockIdx.y158{132}, iS157{( ceilDiv(( ceilDiv(i0, 2) ), 132) )}, iS156{2}, ithreadIdx.x216{( ceilDiv(( ceilDiv(i2, 4) ), 8) )}_p, iS215{8}, iS214{4}] ca_pos( 6 ) produce_pos( 2 )
 logical domain : (iS4{i0}, iS5{i2})
 contiguity: t t
  Split: iS5{i2} by factor 4 -> iS213{( ceilDiv(i2, 4) )}, iS214{4}
  Split: iS4{i0} by factor 2 -> iS155{( ceilDiv(i0, 2) )}, iS156{2}
  Split: iS155{( ceilDiv(i0, 2) )} by factor 132 -> iS157{( ceilDiv(( ceilDiv(i0, 2) ), 132) )}, iblockIdx.y158{132}
  Outer split: iS213{( ceilDiv(i2, 4) )} by factor 8 -> iS215{8}, ithreadIdx.x216{( ceilDiv(( ceilDiv(i2, 4) ), 8) )}_p
 loop domain : (iblockIdx.y158{132}, iS157{( ceilDiv(( ceilDiv(i0, 2) ), 132) )}, iS156{2}, ithreadIdx.x216{( ceilDiv(( ceilDiv(i2, 4) ), 8) )}_p, iS215{8}, iS214{4})
T3_l_float[iblockIdx.y102{132}, iS101{( ceilDiv(( ceilDiv(i0, 2) ), 132) )}, iS100{2}, ithreadIdx.x172{( ceilDiv(( ceilDiv(i2, 4) ), 8) )}_p, iS171{8}, iS170{4}] ca_pos( 6 ) produce_pos( 2 )
 logical domain : (iS22{i0}, iS23{i2})
 contiguity: t t
  Split: iS23{i2} by factor 4 -> iS169{( ceilDiv(i2, 4) )}, iS170{4}
  Split: iS22{i0} by factor 2 -> iS99{( ceilDiv(i0, 2) )}, iS100{2}
  Split: iS99{( ceilDiv(i0, 2) )} by factor 132 -> iS101{( ceilDiv(( ceilDiv(i0, 2) ), 132) )}, iblockIdx.y102{132}
  Outer split: iS169{( ceilDiv(i2, 4) )} by factor 8 -> iS171{8}, ithreadIdx.x172{( ceilDiv(( ceilDiv(i2, 4) ), 8) )}_p
 loop domain : (iblockIdx.y102{132}, iS101{( ceilDiv(( ceilDiv(i0, 2) ), 132) )}, iS100{2}, ithreadIdx.x172{( ceilDiv(( ceilDiv(i2, 4) ), 8) )}_p, iS171{8}, iS170{4})
T4_l_float[iblockIdx.y98{132}, iS97{( ceilDiv(( ceilDiv(i0, 2) ), 132) )}, iS96{2}, ithreadIdx.x168{( ceilDiv(( ceilDiv(i2, 4) ), 8) )}_p, iS167{8}, iS166{4}] ca_pos( 6 ) produce_pos( 6 )
 logical domain : (iS8{i0}, iS9{i2})
 contiguity: t t
  Split: iS9{i2} by factor 4 -> iS165{( ceilDiv(i2, 4) )}, iS166{4}
  Split: iS8{i0} by factor 2 -> iS95{( ceilDiv(i0, 2) )}, iS96{2}
  Split: iS95{( ceilDiv(i0, 2) )} by factor 132 -> iS97{( ceilDiv(( ceilDiv(i0, 2) ), 132) )}, iblockIdx.y98{132}
  Outer split: iS165{( ceilDiv(i2, 4) )} by factor 8 -> iS167{8}, ithreadIdx.x168{( ceilDiv(( ceilDiv(i2, 4) ), 8) )}_p
 loop domain : (iblockIdx.y98{132}, iS97{( ceilDiv(( ceilDiv(i0, 2) ), 132) )}, iS96{2}, ithreadIdx.x168{( ceilDiv(( ceilDiv(i2, 4) ), 8) )}_p, iS167{8}, iS166{4})
T20_l_float[iblockIdx.y62{132}, iS61{( ceilDiv(( ceilDiv(i0, 2) ), 132) )}, iS60{2}, ithreadIdx.x66{( ceilDiv(( ceilDiv(i2, 4) ), 8) )}rf_p, rS65{8}rf, rS64{4}rf] ca_pos( 2 ) produce_pos( 6 )
 root domain : (iS57{i0}, rS58{i2}rf)
  Split: rS58{i2}rf by factor 4 -> rS63{( ceilDiv(i2, 4) )}rf, rS64{4}rf
  Outer split: rS63{( ceilDiv(i2, 4) )}rf by factor 8 -> rS65{8}rf, ithreadIdx.x66{( ceilDiv(( ceilDiv(i2, 4) ), 8) )}rf_p
 logical domain : (iS57{i0}, rS65{8}rf, ithreadIdx.x66{( ceilDiv(( ceilDiv(i2, 4) ), 8) )}rf_p, rS64{4}rf)
 contiguity: t n t n
  Split: iS57{i0} by factor 2 -> iS59{( ceilDiv(i0, 2) )}, iS60{2}
  Split: rS58{i2}rf by factor 4 -> rS63{( ceilDiv(i2, 4) )}rf, rS64{4}rf
  Split: iS59{( ceilDiv(i0, 2) )} by factor 132 -> iS61{( ceilDiv(( ceilDiv(i0, 2) ), 132) )}, iblockIdx.y62{132}
  Outer split: rS63{( ceilDiv(i2, 4) )}rf by factor 8 -> rS65{8}rf, ithreadIdx.x66{( ceilDiv(( ceilDiv(i2, 4) ), 8) )}rf_p
 loop domain : (iblockIdx.y62{132}, iS61{( ceilDiv(( ceilDiv(i0, 2) ), 132) )}, iS60{2}, ithreadIdx.x66{( ceilDiv(( ceilDiv(i2, 4) ), 8) )}rf_p, rS65{8}rf, rS64{4}rf)
T5_l_float[iblockIdx.y72{132}, iS71{( ceilDiv(( ceilDiv(i0, 2) ), 132) )}, iG70{2}, rthreadIdx.x68{( ceilDiv(( ceilDiv(i2, 4) ), 8) )}_p] ca_pos( 2 ) produce_pos( 2 )
 logical domain : (iS67{i0}, rthreadIdx.x68{( ceilDiv(( ceilDiv(i2, 4) ), 8) )}_p)
 contiguity: t n
  Split: iS67{i0} by factor 2 -> iS69{( ceilDiv(i0, 2) )}, iG70{2}
  Split: iS69{( ceilDiv(i0, 2) )} by factor 132 -> iS71{( ceilDiv(( ceilDiv(i0, 2) ), 132) )}, iblockIdx.y72{132}
 loop domain : (iblockIdx.y72{132}, iS71{( ceilDiv(( ceilDiv(i0, 2) ), 132) )}, iG70{2}, rthreadIdx.x68{( ceilDiv(( ceilDiv(i2, 4) ), 8) )}_p)
T6_l_float[iblockIdx.y146{132}, iS145{( ceilDiv(( ceilDiv(i0, 2) ), 132) )}, iS144{2}, bthreadIdx.x200{1}_p, bS199{8}, bS198{4}] ca_pos( 3 ) produce_pos( 2 )
 logical domain : (iS12{i0}, bS13{1})
 allocation domain : (iS12{i0}, bS13{1})
 contiguity: t n
  Split: bS13{1} by factor 4 -> bS197{1}, bS198{4}
  Split: iS12{i0} by factor 2 -> iS143{( ceilDiv(i0, 2) )}, iS144{2}
  Split: iS143{( ceilDiv(i0, 2) )} by factor 132 -> iS145{( ceilDiv(( ceilDiv(i0, 2) ), 132) )}, iblockIdx.y146{132}
  Outer split: bS197{1} by factor 8 -> bS199{8}, bthreadIdx.x200{1}_p
 loop domain : (iblockIdx.y146{132}, iS145{( ceilDiv(( ceilDiv(i0, 2) ), 132) )}, iS144{2}, bthreadIdx.x200{1}_p, bS199{8}, bS198{4})
T7_l_float[iblockIdx.y142{132}, iS141{( ceilDiv(( ceilDiv(i0, 2) ), 132) )}, iS140{2}, ithreadIdx.x196{( ceilDiv(( ceilDiv(i2, 4) ), 8) )}_p, iS195{8}, iS194{4}] ca_pos( 6 ) produce_pos( 6 )
 logical domain : (iS14{i0}, iS15{i2})
 contiguity: t t
  Split: iS15{i2} by factor 4 -> iS193{( ceilDiv(i2, 4) )}, iS194{4}
  Split: iS14{i0} by factor 2 -> iS139{( ceilDiv(i0, 2) )}, iS140{2}
  Split: iS139{( ceilDiv(i0, 2) )} by factor 132 -> iS141{( ceilDiv(( ceilDiv(i0, 2) ), 132) )}, iblockIdx.y142{132}
  Outer split: iS193{( ceilDiv(i2, 4) )} by factor 8 -> iS195{8}, ithreadIdx.x196{( ceilDiv(( ceilDiv(i2, 4) ), 8) )}_p
 loop domain : (iblockIdx.y142{132}, iS141{( ceilDiv(( ceilDiv(i0, 2) ), 132) )}, iS140{2}, ithreadIdx.x196{( ceilDiv(( ceilDiv(i2, 4) ), 8) )}_p, iS195{8}, iS194{4})
T16_l___bfloat[iblockIdx.y150{132}, iS149{( ceilDiv(( ceilDiv(i0, 2) ), 132) )}, iS148{2}, ithreadIdx.x204{( ceilDiv(( ceilDiv(i2, 4) ), 8) )}_p, iS203{8}, iS202{4}] ca_pos( 2 ) produce_pos( 6 )
 logical domain : (iS18{i0}, iS19{i2})
 contiguity: t t
  Split: iS19{i2} by factor 4 -> iS201{( ceilDiv(i2, 4) )}, iS202{4}
  Split: iS18{i0} by factor 2 -> iS147{( ceilDiv(i0, 2) )}, iS148{2}
  Split: iS147{( ceilDiv(i0, 2) )} by factor 132 -> iS149{( ceilDiv(( ceilDiv(i0, 2) ), 132) )}, iblockIdx.y150{132}
  Outer split: iS201{( ceilDiv(i2, 4) )} by factor 8 -> iS203{8}, ithreadIdx.x204{( ceilDiv(( ceilDiv(i2, 4) ), 8) )}_p
 loop domain : (iblockIdx.y150{132}, iS149{( ceilDiv(( ceilDiv(i0, 2) ), 132) )}, iS148{2}, ithreadIdx.x204{( ceilDiv(( ceilDiv(i2, 4) ), 8) )}_p, iS203{8}, iS202{4})
T9_g___bfloat[iblockIdx.y154{132}, iS153{( ceilDiv(( ceilDiv(i0, 2) ), 132) )}, iUR152{2}, ithreadIdx.x208{( ceilDiv(( ceilDiv(i2, 4) ), 8) )}_p, iS207{8}, iV206{4}] ca_pos( 2 ) produce_pos( 2 )
 logical domain : (iS42{i0}, iS43{i2})
 contiguity: t t
  Split: iS43{i2} by factor 4 -> iS205{( ceilDiv(i2, 4) )}, iV206{4}
  Split: iS42{i0} by factor 2 -> iS151{( ceilDiv(i0, 2) )}, iUR152{2}
  Split: iS151{( ceilDiv(i0, 2) )} by factor 132 -> iS153{( ceilDiv(( ceilDiv(i0, 2) ), 132) )}, iblockIdx.y154{132}
  Outer split: iS205{( ceilDiv(i2, 4) )} by factor 8 -> iS207{8}, ithreadIdx.x208{( ceilDiv(( ceilDiv(i2, 4) ), 8) )}_p
 loop domain : (iblockIdx.y154{132}, iS153{( ceilDiv(( ceilDiv(i0, 2) ), 132) )}, iUR152{2}, ithreadIdx.x208{( ceilDiv(( ceilDiv(i2, 4) ), 8) )}_p, iS207{8}, iV206{4})
T22_l_float[iblockIdx.y82{132}rf, rS81{( ceilDiv(( ceilDiv(i0, 2) ), 132) )}rf, rS80{2}rf, ithreadIdx.x212{( ceilDiv(( ceilDiv(i2, 4) ), 8) )}_p, iS211{8}, iS210{4}] ca_pos( 1 ) produce_pos( 6 )
 root domain : (rS77{i0}rf, iS78{i2})
  Split: rS77{i0}rf by factor 2 -> rS79{( ceilDiv(i0, 2) )}rf, rS80{2}rf
  Split: rS79{( ceilDiv(i0, 2) )}rf by factor 132 -> rS81{( ceilDiv(( ceilDiv(i0, 2) ), 132) )}rf, iblockIdx.y82{132}rf
 logical domain : (rS81{( ceilDiv(( ceilDiv(i0, 2) ), 132) )}rf, iblockIdx.y82{132}rf, rS80{2}rf, iS78{i2})
 contiguity: n t n t
  Split: iS78{i2} by factor 4 -> iS209{( ceilDiv(i2, 4) )}, iS210{4}
  Split: rS77{i0}rf by factor 2 -> rS79{( ceilDiv(i0, 2) )}rf, rS80{2}rf
  Split: rS79{( ceilDiv(i0, 2) )}rf by factor 132 -> rS81{( ceilDiv(( ceilDiv(i0, 2) ), 132) )}rf, iblockIdx.y82{132}rf
  Outer split: iS209{( ceilDiv(i2, 4) )} by factor 8 -> iS211{8}, ithreadIdx.x212{( ceilDiv(( ceilDiv(i2, 4) ), 8) )}_p
 loop domain : (iblockIdx.y82{132}rf, rS81{( ceilDiv(( ceilDiv(i0, 2) ), 132) )}rf, rS80{2}rf, ithreadIdx.x212{( ceilDiv(( ceilDiv(i2, 4) ), 8) )}_p, iS211{8}, iS210{4})
T21_g_float[iblockIdx.y85{132}, ithreadIdx.x232{( ceilDiv(( ceilDiv(i2, 4) ), 8) )}_p, iS231{8}, iV230{4}] produce_pos( 1 )
 logical domain : (iblockIdx.y85{132}, iS86{i2})
 contiguity: t t
  Split: iS86{i2} by factor 4 -> iS229{( ceilDiv(i2, 4) )}, iV230{4}
  Outer split: iS229{( ceilDiv(i2, 4) )} by factor 8 -> iS231{8}, ithreadIdx.x232{( ceilDiv(( ceilDiv(i2, 4) ), 8) )}_p
 loop domain : (iblockIdx.y85{132}, ithreadIdx.x232{( ceilDiv(( ceilDiv(i2, 4) ), 8) )}_p, iS231{8}, iV230{4})
T23_l_float[iS233{132}, iS234{1}, iblockIdx.y237{( ceilDiv(( ceilDiv(i2, 4) ), 128) )}, ithreadIdx.x238{128}, iV236{4}] ca_pos( 4 )
 logical domain : (iS87{132}, iS88{i2})
 contiguity: t t
  Split: iS88{i2} by factor 4 -> iS235{( ceilDiv(i2, 4) )}, iV236{4}
  Split: iS87{132} by factor 1 -> iS233{132}, iS234{1}
  Split: iS235{( ceilDiv(i2, 4) )} by factor 128 -> iblockIdx.y237{( ceilDiv(( ceilDiv(i2, 4) ), 128) )}, ithreadIdx.x238{128}
 loop domain : (iS233{132}, iS234{1}, iblockIdx.y237{( ceilDiv(( ceilDiv(i2, 4) ), 128) )}, ithreadIdx.x238{128}, iV236{4})
T17_l_float[rS89{132}, rS90{1}, iblockIdx.y93{( ceilDiv(( ceilDiv(i2, 4) ), 128) )}, ithreadIdx.x94{128}, iS92{4}] produce_pos( 4 )
 logical domain : (rS83{132}, iS84{i2})
 contiguity: n t
  Split: iS84{i2} by factor 4 -> iS91{( ceilDiv(i2, 4) )}, iS92{4}
  Split: rS83{132} by factor 1 -> rS89{132}, rS90{1}
  Split: iS91{( ceilDiv(i2, 4) )} by factor 128 -> iblockIdx.y93{( ceilDiv(( ceilDiv(i2, 4) ), 128) )}, ithreadIdx.x94{128}
 loop domain : (rS89{132}, rS90{1}, iblockIdx.y93{( ceilDiv(( ceilDiv(i2, 4) ), 128) )}, ithreadIdx.x94{128}, iS92{4})
T8_g_float[iblockIdx.y241{( ceilDiv(( ceilDiv(i2, 4) ), 128) )}, ithreadIdx.x242{128}, iV240{4}] ca_pos( 2 )
 logical domain : (iS44{i2})
 contiguity: t
  Split: iS44{i2} by factor 4 -> iS239{( ceilDiv(i2, 4) )}, iV240{4}
  Split: iS239{( ceilDiv(i2, 4) )} by factor 128 -> iblockIdx.y241{( ceilDiv(( ceilDiv(i2, 4) ), 128) )}, ithreadIdx.x242{128}
 loop domain : (iblockIdx.y241{( ceilDiv(( ceilDiv(i2, 4) ), 128) )}, ithreadIdx.x242{128}, iV240{4})
} // %kernel

CUDA Kernel:

__syncthreads();
  if ((((nvfuser_index_t)threadIdx.y) >= 1)) {
    #pragma unroll 1
    for(nvfuser_index_t i40 = 0; i40 < i0; ++i40) {
      nvfuser_index_t i41;
      i41 = i7 * i40;
      __bfloat* ptr42;
      ptr42 = ptr6 + i41;
      nvfuser_index_t i43;
      i43 = i1 * (i40 % 2);
      uint32_t i44;
      i44 = i8 + i43;
      __bfloat* ptr45;
      ptr45 = ptr9 + i41;
      uint32_t i46;
      i46 = i10 + i43;
      if ((Hopper::electSync(4294967295U) && b19)) {
        mbarrier::waitParity(toSmem((&T25[((i40 % 2) + 2LL)])), (uint32_t)(((i40 / 2) % 2)));
        mbarrier::arriveExpectTX(toSmem((&T25[(i40 % 2)])), i2);
        #pragma unroll
        for(nvfuser_index_t i47 = 0; i47 < 2; ++i47) {
          Hopper::cpAsyncBulkG2S((Hopper::CpAsyncBulkG2SIndex{ (ptr42 + (T0.logical_size[1LL] * i47)), i4, toSmem((&T25[(i40 % 2)])) }), (i44 + (i3 * i47)));
        }
        mbarrier::arriveExpectTX(toSmem((&T25[(i40 % 2)])), i2);
        #pragma unroll
        for(nvfuser_index_t i48 = 0; i48 < 2; ++i48) {
          Hopper::cpAsyncBulkG2S((Hopper::CpAsyncBulkG2SIndex{ (ptr45 + (T0.logical_size[1LL] * i48)), i4, toSmem((&T25[(i40 % 2)])) }), (i46 + (i3 * i48)));
        }
      }
    }
    return;
  } else {
    #pragma unroll
    for(nvfuser_index_t i49 = 0; i49 < 2; ++i49) {
      mbarrier::arrive(toSmem((&T25[(i49 + 2LL)])));
    }
    #pragma unroll 1
    for(nvfuser_index_t i50 = 0; i50 < i0; ++i50) {
      nvfuser_index_t i51;
      i51 = i11 + (i3 * (i50 % 2));
      nvfuser_index_t i52;
      i52 = i15 + (i7 * i50);
      nvfuser_index_t i53;
      i53 = 264 * i50;
      nvfuser_index_t i54;
      i54 = i25 + i53;
      Array<__bfloat, 64, 4> T18;
      #pragma unroll
      for(nvfuser_index_t i55 = 0; i55 < 2; ++i55) {
        nvfuser_index_t i56;
        i56 = 32 * i55;
        #pragma unroll
        for(nvfuser_index_t i57 = 0; i57 < 8; ++i57) {
          T18.set(__bfloat(0));
        }
      }
      mbarrier::waitParity(toSmem((&T25[(i50 % 2)])), (uint32_t)(((i50 / 2) % 2)));
      #pragma unroll
      for(nvfuser_index_t i55 = 0; i55 < 2; ++i55) {
        nvfuser_index_t i58;
        i58 = i51 + (T0.logical_size[1LL] * i55);
        nvfuser_index_t i59;
        i59 = 32 * i55;
        bool b60;
        b60 = b23 && (i54 < (-i55));
        #pragma unroll
        for(nvfuser_index_t i57 = 0; i57 < 8; ++i57) {
          nvfuser_index_t i61;
          i61 = i14 * i57;
          if ((b60 && (i26 < (-i61)))) {
            loadGeneric<__bfloat, 4>( &T18[(i59 + (4 * i57))],  &T14[(i58 + i61)]);
          }
        }
      }
      Array<__bfloat, 64, 4> T19;
      #pragma unroll
      for(nvfuser_index_t i62 = 0; i62 < 2; ++i62) {
        nvfuser_index_t i63;
        i63 = 32 * i62;
        #pragma unroll
        for(nvfuser_index_t i64 = 0; i64 < 8; ++i64) {
          T19.set(__bfloat(0));
        }
      }
      #pragma unroll
      for(nvfuser_index_t i62 = 0; i62 < 2; ++i62) {
        nvfuser_index_t i65;
        i65 = i51 + (T0.logical_size[1LL] * i62);
        nvfuser_index_t i66;
        i66 = 32 * i62;
        bool b67;
        b67 = b23 && (i54 < (-i62));
        #pragma unroll
        for(nvfuser_index_t i64 = 0; i64 < 8; ++i64) {
          nvfuser_index_t i68;
          i68 = i14 * i64;
          if ((b67 && (i26 < (-i68)))) {
            loadGeneric<__bfloat, 4>( &T19[(i66 + (4 * i64))],  &T15[(i65 + i68)]);
          }
        }
      }
      mbarrier::arrive(toSmem((&T25[((i50 % 2) + 2LL)])));
      Array<float, 2, 1> T20;
      #pragma unroll
      for(nvfuser_index_t i69 = 0; i69 < 2; ++i69) {
        T20[i69] = 0.000000000e+00f;
      }
      #pragma unroll
      for(nvfuser_index_t i69 = 0; i69 < 2; ++i69) {
        nvfuser_index_t i70;
        i70 = 32 * i69;
        bool b71;
        b71 = b23 && (i54 < (-i69));
        #pragma unroll
        for(nvfuser_index_t i35 = 0; i35 < 8; ++i35) {
          nvfuser_index_t i72;
          i72 = 4 * i35;
          nvfuser_index_t i73;
          i73 = i70 + i72;
          bool b74;
          b74 = b71 && (i26 < (-(i14 * i35)));
          #pragma unroll
          for(nvfuser_index_t i37 = 0; i37 < 4; ++i37) {
            nvfuser_index_t i75;
            i75 = i73 + i37;
            Array<float, 1, 1> T3;
            T3[0]
               = __bfloat2float(T19[i75]);
            if (b74) {
              T22[(i72 + i37)]
                = T22[(i72 + i37)]
                + T3[0];
            }
            Array<float, 1, 1> T2;
            T2[0]
               = __bfloat2float(T18[i75]);
            Array<float, 1, 1> T4;
            T4[0]
              = T2[0]
              + T3[0];
            if (b74) {
              T20[i69]
                = T20[i69]
                + T4[0];
            }
          }
        }
      }
      Array<float, 2, 1> T5;
      #pragma unroll
      for(nvfuser_index_t i76 = 0; i76 < 2; ++i76) {
        T5[i76] = 0.000000000e+00f;
      }
      warp::iterGroupedStaticWarpAllReduce<false, false, 2, 128>(T5.array, T20.array, [](float &a, float b) { a = a + b; }, static_cast<float*>(shared_mem), ((nvfuser_index_t)threadIdx.x));
      // Alias Allocation - register
      auto& T16 = T19;
      #pragma unroll
      for(nvfuser_index_t i77 = 0; i77 < 2; ++i77) {
        nvfuser_index_t i78;
        i78 = 32 * i77;
        Array<float, 1, 1> T6;
        T6[0]
           = T5[i77];
        #pragma unroll
        for(nvfuser_index_t i79 = 0; i79 < 8; ++i79) {
          nvfuser_index_t i80;
          i80 = i78 + (4 * i79);
          #pragma unroll
          for(nvfuser_index_t i81 = 0; i81 < 4; ++i81) {
            nvfuser_index_t i82;
            i82 = i80 + i81;
            Array<float, 1, 1> T11;
            T11[0]
               = __bfloat2float(T19[i82]);
            Array<float, 1, 1> T10;
            T10[0]
               = __bfloat2float(T18[i82]);
            Array<float, 1, 1> T12;
            T12[0]
              = T10[0]
              + T11[0];
            Array<float, 1, 1> T7;
            T7[0]
              = T12[0]
              + T6[0];
            T16[i82]
               = __float2bfloat(T7[0]);
          }
        }
      }
      if ((b27 && ((i28 + i53) < T0.logical_size[0LL]))) {
        #pragma unroll
        for(nvfuser_index_t i83 = 0; i83 < 2; ++i83) {
          nvfuser_index_t i84;
          i84 = 32 * i83;
          nvfuser_index_t i85;
          i85 = i52 + (T0.logical_size[1LL] * i83);
          #pragma unroll
          for(nvfuser_index_t i86 = 0; i86 < 8; ++i86) {
            if (b29) {
              loadLocalToGlobal<__bfloat, /*vec_size=*/4, /*is_volatile=*/false>( &T9[(i85 + (i14 * i86))], &T16[(i84 + (4 * i86))]);
            }
          }
        }
      } else {
        #pragma unroll
        for(nvfuser_index_t i83 = 0; i83 < 2; ++i83) {
          nvfuser_index_t i87;
          i87 = 32 * i83;
          nvfuser_index_t i88;
          i88 = i52 + (T0.logical_size[1LL] * i83);
          bool b89;
          b89 = b30 && (i54 < (-i83));
          #pragma unroll
          for(nvfuser_index_t i86 = 0; i86 < 8; ++i86) {
            nvfuser_index_t i90;
            i90 = i14 * i86;
            if ((b89 && (i26 < (-i90)))) {
              loadLocalToGlobal<__bfloat, /*vec_size=*/4, /*is_volatile=*/false>( &T9[(i88 + i90)], &T16[(i87 + (4 * i86))]);
            }
          }
        }
      }
    }
  }

@liqiangxl (Collaborator, Author):
!test

constexpr unsigned int align_size = sizeof(T) * N;
static_assert(align_size <= 16, "max allowed vect r/w is 16 bytes");
// [warp_idx, N]
// [w0r0, w0r1, w0r2, w0r3, w1r0, w1r1, w1r2, w1r3]
@liqiangxl (Collaborator, Author):

The smem access pattern is illustrated in the figure: [image omitted]

liqiangxl marked this pull request as ready for review April 30, 2025 21:27
liqiangxl requested a review from jjsjann123 April 30, 2025 21:27
for (auto cached_tv : cached_inputs) {
if (cached_tv->hasBroadcast() &&
is_redu_mapped_to_bcast(inner_reference_tv, cached_tv)) {
cached_tv->axis(2)->parallelize(ParallelType::Vectorize);
Reviewer (Collaborator):
We should check for contiguity flag as well as allocation domain for vectorization.

@liqiangxl (Collaborator, Author):

Good point. The heuristic ensures the iteration dim is divisible by the unroll factor, so here we only need to further confirm that all the iteration domains are contiguous. Also extended the tests to cover a case with non-contiguous input.


      auto can_vectorize = [](TensorView* redu_tv, TensorView* bcast_tv) {
        const auto& alloc_dom_1 = redu_tv->getMaybeAllocationDomain();
        const auto& alloc_dom_2 = bcast_tv->getMaybeAllocationDomain();
        if (alloc_dom_1.size() != alloc_dom_2.size()) {
          return false;
        }
        const auto& contiguity = bcast_tv->domain()->contiguity();
        for (int i = 0; i < (int)alloc_dom_1.size(); i++) {
          if (alloc_dom_1[i]->isReduction()) {
            break;
          }
          if (!contiguity[i].has_value() || !contiguity[i].value()) {
            return false;
          }
        }
        return true;
      };

Reviewer (Collaborator):

Thanks a lot for patching this.

A follow-up ask: can we add a test covering the case when the innermost dimension is not contiguous?

csrc/codegen.cpp Outdated
std::pair<IterDomain*, IterDomain*> reduction_dims,
bool is_all_reduce) {
NVF_ERROR(
is_all_reduce,
Reviewer (Collaborator):

This feels strange to pass an arg and assert on it.

@liqiangxl (Collaborator, Author):

Moved it to before the function call:

        NVF_ERROR(
            grouped_rop->isAllreduce(),
            "iterGroupedStaticWarpAllReduce should be used for allreduce.");
        return genGroupedWarpReduction(

csrc/codegen.cpp Outdated
func_args.arg(genStaticCast(genPtrType(output->dtype()), "shared_mem"));

ArgumentBuilder template_args;
if (reduction_dims.first->getParallelType() == ParallelType::TIDx &&
Reviewer (Collaborator):

Nitpick: assert or check that reduction_dims.first != nullptr.

csrc/codegen.cpp Outdated
<< ";\n";
} else {
NVF_THROW(
"Grouped warp reduction is only supported for TIDx reduction with no second dimension");
Reviewer (Collaborator):

Nitpick: if we are just throwing here, we might as well convert this to an assert to avoid branching.

@liqiangxl (Collaborator, Author):

Good point. Moved it out of the function:

        NVF_ERROR(
            lparams_.bdimx() % 128 == 0,
            "iterGroupedStaticWarpAllReduce() requires bdimx % 128 == 0.");
        NVF_ERROR(
            grouped_rop->isAllreduce(),
            "iterGroupedStaticWarpAllReduce should be used for allreduce.");
        NVF_ERROR(
            reduction_dims.first->getParallelType() == ParallelType::TIDx &&
                reduction_dims.second == nullptr,
            "Grouped warp reduction is only supported for TIDx reduction with no second dimension.");
        return genGroupedWarpReduction(
            (int)num_grouped_iterations,
            output,
            input,
            grouped_rop->initVal(0),
            op_type,
            grouped_rop->predicate());

runtime/warp.cu Outdated

// sizeof(T) * K = sizeof(uint64_t)
// require alignment of sizeof(T) * K to safely cast between T and uint64_t
// shfl uses uint64_t to reduce number of shuffles
Reviewer (Collaborator):

Nitpick question: since this is a runtime file, I don't think we can use concepts here, can we?

@liqiangxl (Collaborator, Author) commented May 2, 2025:

I didn't check that, but the Array structure ensures data is aligned for safe casting to uint64_t, so we don't actually need extra checks:

template <typename scalar_t, int size, int align_size = 1>
struct alignas(sizeof(scalar_t) * align_size) Array {
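
For illustration, a hypothetical pack-and-shuffle helper relying on that alignment guarantee (not the runtime's actual packedWarpReduce): with T = __bfloat and K = 4, one uint64_t shuffle moves four elements at once.

      // vals must point into an Array aligned to sizeof(uint64_t), as
      // guaranteed by the alignas above when align_size == K.
      template <typename T, int K>
      __device__ uint64_t shufflePackedSketch(const T* vals, int lane_mask) {
        static_assert(sizeof(T) * K == sizeof(uint64_t), "must pack to 64 bits");
        uint64_t packed = *reinterpret_cast<const uint64_t*>(vals);
        return __shfl_xor_sync(0xffffffffU, packed, lane_mask);
      }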

@liqiangxl (Collaborator, Author):
!test

const T inp_val[N],
Func reduction_op,
T* shared_mem,
uint32_t threadIdx_x,
Reviewer (Collaborator):

Dumb question: why are we passing threadIdx_x as an arg instead of just using threadIdx.x as-is here?
