diff --git a/csrc/predicate_compute.cpp b/csrc/predicate_compute.cpp index 86c8d362f4c..d1d29e6062f 100644 --- a/csrc/predicate_compute.cpp +++ b/csrc/predicate_compute.cpp @@ -490,14 +490,12 @@ namespace { // Select first warp of threads along TIDx axis and then use ptx::elect_sync // TODO If TIDx is known at compile-time, generate custom mask. -Val* createElectSyncPredicate(bool use_first_warp = true) { - Val* warp_size = IrBuilder::create(32L, PrimDataType::UInt64); - Val* full_mask_val = IrBuilder::create(0xFFFFFFFF, PrimDataType::UInt32); - Val* elect_sync_val = IrBuilder::create(PrimDataType::Bool); - IrBuilder::create( - UnaryOpType::ElectSync, elect_sync_val, full_mask_val); +Val* createElectSyncPredicate( + bool use_first_warp = true, + bool is_warp_collective = false) { // If TIDx is used for both computation and TMA load, we should select a // thread from the last warp along TIDx. + Val* warp_size = IrBuilder::create(32L, PrimDataType::UInt64); auto select_warp = use_first_warp ? IrBuilder::ltExpr( NamedScalar::getParallelIndex(ParallelType::TIDx), warp_size) @@ -506,6 +504,18 @@ Val* createElectSyncPredicate(bool use_first_warp = true) { IrBuilder::addExpr( NamedScalar::getParallelDim(ParallelType::TIDx), IrBuilder::create(-32L, PrimDataType::Index))); + + // Short-Circuit: TMA Store is a warp-collective, so ElectSync is not + // necessary. + if (is_warp_collective) { + return select_warp; + } + + // Create ElectSync Predicate to pick any thread in the warp + Val* full_mask_val = IrBuilder::create(0xFFFFFFFF, PrimDataType::UInt32); + Val* elect_sync_val = IrBuilder::create(PrimDataType::Bool); + IrBuilder::create( + UnaryOpType::ElectSync, elect_sync_val, full_mask_val); return SimplifyingIrBuilder::logicalAndExpr(elect_sync_val, select_warp); } @@ -537,14 +547,18 @@ Val* createElectSyncPredicate(kir::Predicate* pred) { } // short-circuit: Expect ParallelType::TIDx to have at least one warp. + bool is_tma_store = ir_utils::isCpAsyncBulkStore(pred->expr()); if (tidx_paralleltype_dim->isConstScalar() && tidx_paralleltype_dim->evaluate().as() < 32) { - Val* zero = IrBuilder::create(0L, PrimDataType::UInt64); - return IrBuilder::eqExpr( - NamedScalar::getParallelIndex(ParallelType::TIDx), zero); + if (is_tma_store) { + return pred->fusion()->trueVal(); + } else { + Val* zero = IrBuilder::create(0L, PrimDataType::UInt64); + return IrBuilder::eqExpr( + NamedScalar::getParallelIndex(ParallelType::TIDx), zero); + } } - - return createElectSyncPredicate(); + return createElectSyncPredicate(/*use_first_warp=*/true, is_tma_store); } Val* createSingleExpressionElectSync( @@ -559,9 +573,12 @@ Val* createSingleExpressionElectSync( TensorView* out_tv = ir_utils::getTvOutput(pred->expr()); Val* zero = IrBuilder::create(0L, PrimDataType::UInt64); - const auto& pdim_map = GpuLower::current()->parallelDimensionMap(); + + const ParallelDimensionMap& pdim_map = + GpuLower::current()->parallelDimensionMap(); auto pred_map = ParallelizedDomainPredicate::getPredicateMap(pred->expr(), loops); + Val* parallel_dom_pred = GpuLower::current()->kernel()->trueVal(); for (auto pt : {ParallelType::TIDx, ParallelType::TIDy, ParallelType::TIDz}) { // short-circuit: parallelDim is not used by CTA @@ -574,10 +591,10 @@ Val* createSingleExpressionElectSync( // exists. auto pred_info_it = pred_map.find(pt); if (pred_info_it != pred_map.end()) { - const auto& pred_info = pred_info_it->second; - auto tid_pred = pred_info.getPredicate(); - parallel_dom_pred = - SimplifyingIrBuilder::logicalAndExpr(parallel_dom_pred, tid_pred); + const ParallelizedDomainPredicate::PredicateInfo& pred_info = + pred_info_it->second; + parallel_dom_pred = SimplifyingIrBuilder::logicalAndExpr( + parallel_dom_pred, pred_info.getPredicate()); } // Case 2: ParallelDim is used by CTA but not the TMA/Blackwell MMA @@ -621,7 +638,8 @@ Val* createMultipleExpressionElectSync( NVF_ERROR(pred->expr() == nullptr); Val* zero = IrBuilder::create(0L, PrimDataType::UInt64); - const auto& pdim_map = GpuLower::current()->parallelDimensionMap(); + const ParallelDimensionMap& pdim_map = + GpuLower::current()->parallelDimensionMap(); // Determine if warp specialized tma load expression. ParallelType async_warp_on = ParallelType::Serial; @@ -649,7 +667,7 @@ Val* createMultipleExpressionElectSync( Val* conditional = async_warp_on == ParallelType::TIDx ? pred->fusion()->trueVal() : createElectSyncPredicate(); - for (auto pt : {ParallelType::TIDy, ParallelType::TIDz}) { + for (ParallelType pt : {ParallelType::TIDy, ParallelType::TIDz}) { if (pdim_map.has(pt) && async_warp_on != pt) { conditional = SimplifyingIrBuilder::logicalAndExpr( conditional, @@ -663,7 +681,7 @@ Val* createMultipleExpressionElectSync( // we can use the first warp, otherwise should use the last warp. bool use_first_warp = async_warp_on != ParallelType::TIDx; Val* conditional = createElectSyncPredicate(use_first_warp); - for (auto pt : {ParallelType::TIDy, ParallelType::TIDz}) { + for (ParallelType pt : {ParallelType::TIDy, ParallelType::TIDz}) { if (!pdim_map.has(pt)) { continue; } diff --git a/tests/cpp/test_memory.cpp b/tests/cpp/test_memory.cpp index 57d400e1f1d..f65a48c7f31 100644 --- a/tests/cpp/test_memory.cpp +++ b/tests/cpp/test_memory.cpp @@ -235,8 +235,14 @@ class XorFinder : private kir::IrVisitor { class TMAPredicateChecker : private kir::IrVisitor { int64_t num_threads_; int64_t cta_threads_; - TMAPredicateChecker(int64_t num_threads, int64_t cta_threads) - : num_threads_(num_threads), cta_threads_(cta_threads) {} + bool is_tma_store_; + TMAPredicateChecker( + int64_t num_threads, + int64_t cta_threads, + bool is_tma_store) + : num_threads_(num_threads), + cta_threads_(cta_threads), + is_tma_store_(is_tma_store) {} kir::Predicate* pred_ = nullptr; @@ -269,8 +275,31 @@ class TMAPredicateChecker : private kir::IrVisitor { ASSERT_NE(pred_, nullptr); auto cond = pred_->value(); ASSERT_NE(cond, nullptr); + + // Handle TMA Store first + if (is_tma_store_) { + if (cta_threads_ <= 32) { + EXPECT_TRUE(cond->isTrue()); + } else { + auto def = dynamic_cast(cond->definition()); + ASSERT_TRUE(def != nullptr); + EXPECT_TRUE(def->getBinaryOpType() == BinaryOpType::LT); + auto lhs = dynamic_cast(def->lhs()); + auto rhs = def->rhs(); + ASSERT_TRUE(lhs != nullptr); + ASSERT_TRUE(rhs != nullptr); + EXPECT_TRUE(lhs->isThreadIdx()); + EXPECT_TRUE(rhs->isConstInt()); + EXPECT_EQ(rhs->value(), 32); + } + return; + } + + // Then, handle TMA Load if (num_threads_ == 0) { EXPECT_TRUE(cond->isTrue()); + } else if (is_tma_store_ && cta_threads_ <= 32) { + EXPECT_TRUE(cond->isTrue()); } else if (num_threads_ == 1 && cta_threads_ > 32) { auto def = dynamic_cast(cond->definition()); ASSERT_TRUE(def != nullptr); @@ -324,8 +353,9 @@ class TMAPredicateChecker : private kir::IrVisitor { static void checkPredicate( kir::Kernel* kernel, int64_t num_threads, - int64_t cta_threads = -1) { - TMAPredicateChecker checker(num_threads, cta_threads); + int64_t cta_threads = -1, + bool is_tma_store = false) { + TMAPredicateChecker checker(num_threads, cta_threads, is_tma_store); checker.handle(kernel->topLevelExprs()); } }; @@ -611,7 +641,10 @@ TEST_P(TMASimpleLdstTest, Store) { EXPECT_EQ(TMADimChecker::getDim(ke.compiledKernel()->kernel()), dim); TMAPredicateChecker::checkPredicate( - ke.compiledKernel()->kernel(), 1, ke.lastLaunchParams().nThreads()); + ke.compiledKernel()->kernel(), + 1, + ke.lastLaunchParams().nThreads(), + /*is_tma_store=*/true); ASSERT_EQ( XorFinder::findXor(ke.compiledKernel()->kernel()), (swizzle != MmaInputSmemSwizzle::None)); @@ -2482,7 +2515,11 @@ TEST_F(TMADocTest, Figure14d) { ke.compile(&fusion, {t0}, {}, matmul_cparams); EXPECT_EQ(TMADimChecker::getDim(ke.compiledKernel()->kernel()), 2); - TMAPredicateChecker::checkPredicate(ke.compiledKernel()->kernel(), 1); + TMAPredicateChecker::checkPredicate( + ke.compiledKernel()->kernel(), + 1, + ke.lastLaunchParams().nThreads(), + /*is_tma_store=*/true); auto cg_outputs = ke.run({t0}); testValidate(&fusion, cg_outputs, {t0}, {t0}, __LINE__, __FILE__); @@ -2565,7 +2602,10 @@ TEST_F(TMADocTest, Figure14e) { EXPECT_EQ(TMADimChecker::getDim(ke.compiledKernel()->kernel()), 2); TMAPredicateChecker::checkPredicate( - ke.compiledKernel()->kernel(), 1, ke.lastLaunchParams().nThreads()); + ke.compiledKernel()->kernel(), + 1, + ke.lastLaunchParams().nThreads(), + /*is_tma_store=*/true); } TEST_F(TMADocTest, Figure15a) {