diff --git a/csrc/parallel_dimension_map.cpp b/csrc/parallel_dimension_map.cpp index 80f424a9a31..462479b69fe 100644 --- a/csrc/parallel_dimension_map.cpp +++ b/csrc/parallel_dimension_map.cpp @@ -301,6 +301,39 @@ Val* ParallelDimensionMap::getNumComputeThreadsEachBlock() const { return num_threads; } +// For warp-specialization, the CTA is padded so the AsyncWarp contains 128 +// threads. This function maps the AsyncWarp CTA to a linear index from +// [0, 128). It is used to divide AsyncWarp into four independent warps. +Val* ParallelDimensionMap::getLinearThreadIndexAsync() const { + Val* index = GpuLower::current()->kernel()->zeroVal(); + Val* extent = GpuLower::current()->kernel()->oneVal(); + + for (auto pt : kParallelTypeTIDs) { + // For warp-specialization, an axis is padded so the AsyncWarp contains + // 128 threads. + Val* extent_for_pdim = getRawAsync(pt); + // short-circuit: extent_for_pdim is not used in kernel. + if (extent_for_pdim == nullptr) { + continue; + } + // short-circuit: extent_for_pdim is trivial. + if (extent_for_pdim->isConstScalar() && + extent_for_pdim->evaluate().as<int64_t>() == 1) { + continue; + } + Val* pt_index = NamedScalar::getParallelIndex(pt); + // Map the padded parallel index to [0, padded_value] range, so the linear + // index will be in range of [0, 128).
+ if (warp_specialized_types_.count(pt)) { + pt_index = SimplifyingIrBuilder::subExpr(pt_index, getRawCompute(pt)); + } + index = SimplifyingIrBuilder::addExpr( + index, SimplifyingIrBuilder::mulExpr(pt_index, extent)); + extent = SimplifyingIrBuilder::mulExpr(extent, extent_for_pdim); + } + return index; +} + int64_t ParallelDimensionMap::getWarpSpecializationPaddedVal( ParallelType pt) const { NVF_ERROR( @@ -315,6 +348,27 @@ int64_t ParallelDimensionMap::getWarpSpecializationPaddedVal( return ws_with_register_sharing_pad_val_.value(); } +bool ParallelDimensionMap::canUseElectSyncInAsyncWarp() const { + // short-circuit: skip if warp specialization is not enabled + if (warp_specialized_types_.empty()) { + return true; + } + // Currently only support one warp specialized axis + NVF_ERROR(warp_specialized_types_.size() == 1); + ParallelType ws_pt = *warp_specialized_types_.begin(); + + // Check that BlockDim.x >= 32 active threads in AsyncWarp + if (ws_pt != ParallelType::TIDx) { + return true; + } + + if (getWarpSpecializationPaddedVal(ws_pt) >= 32) { + return true; + } + + return false; +} + std::string ParallelDimensionMap::toString() const { std::stringstream ss; for (auto pt : kParallelTypeThreads) { diff --git a/csrc/parallel_dimension_map.h b/csrc/parallel_dimension_map.h index 2f88ab612aa..f3f40faff1c 100644 --- a/csrc/parallel_dimension_map.h +++ b/csrc/parallel_dimension_map.h @@ -73,6 +73,9 @@ class ParallelDimensionMap { //! buffer tensors. Val* getNumComputeThreadsEachBlock() const; + //! Assign linear index to each thread of CTA. Assume (TDZ, TDY, TDX) order. + Val* getLinearThreadIndexAsync() const; + //! Get if the kernel uses warp specialization bool hasWarpSpecialization() const { return !warp_specialized_types_.empty(); @@ -82,6 +85,10 @@ class ParallelDimensionMap { return dim_map_.count(pt) > 0; } + // If warp specialized on TIDx and padded value is less than 32 threads, then + // elect-sync cannot be used. 
+ bool canUseElectSyncInAsyncWarp() const; + private: //! Get number of threads for ParallelType axis //! Not used: 1, Const: n, Dynamic: -1 diff --git a/csrc/predicate_compute.cpp b/csrc/predicate_compute.cpp index d1d29e6062f..aa27bdc7637 100644 --- a/csrc/predicate_compute.cpp +++ b/csrc/predicate_compute.cpp @@ -488,38 +488,63 @@ std::size_t UnswitchPredicateKeyHash::operator()( namespace { -// Select first warp of threads along TIDx axis and then use ptx::elect_sync +// Create elect-sync to pick a thread +Val* createElectSyncExpr() { + Val* full_mask_val = IrBuilder::create<Val>(0xFFFFFFFF, PrimDataType::UInt32); + Val* elect_sync_val = IrBuilder::create<Val>(PrimDataType::Bool); + IrBuilder::create<UnaryOp>( + UnaryOpType::ElectSync, elect_sync_val, full_mask_val); + return elect_sync_val; +} + +// Select first warp of threads along TIDx axis and use ptx::elect_sync if not +// warp collective. // TODO If TIDx is known at compile-time, generate custom mask. -Val* createElectSyncPredicate( - bool use_first_warp = true, - bool is_warp_collective = false) { - // If TIDx is used for both computation and TMA load, we should select a - // thread from the last warp along TIDx. +Val* selectFirstWarpElectSyncPredicate(bool is_warp_collective) { Val* warp_size = IrBuilder::create<Val>(32L, PrimDataType::UInt64); - auto select_warp = use_first_warp - ? IrBuilder::ltExpr( - NamedScalar::getParallelIndex(ParallelType::TIDx), warp_size) - : IrBuilder::geExpr( - NamedScalar::getParallelIndex(ParallelType::TIDx), - IrBuilder::addExpr( - NamedScalar::getParallelDim(ParallelType::TIDx), - IrBuilder::create<Val>(-32L, PrimDataType::Index))); + Val* select_first_warp = IrBuilder::ltExpr( - NamedScalar::getParallelIndex(ParallelType::TIDx), warp_size); // Short-Circuit: TMA Store is a warp-collective, so ElectSync is not // necessary.
if (is_warp_collective) { - return select_warp; + return select_first_warp; } - // Create ElectSync Predicate to pick any thread in the warp - Val* full_mask_val = IrBuilder::create<Val>(0xFFFFFFFF, PrimDataType::UInt32); - Val* elect_sync_val = IrBuilder::create<Val>(PrimDataType::Bool); - IrBuilder::create<UnaryOp>( - UnaryOpType::ElectSync, elect_sync_val, full_mask_val); - return SimplifyingIrBuilder::logicalAndExpr(elect_sync_val, select_warp); + return SimplifyingIrBuilder::logicalAndExpr( + createElectSyncExpr(), select_first_warp); } -Val* createElectSyncPredicate(kir::Predicate* pred) { +// Get linear index for AsyncWarp Group. Then, select first warp. Finally, use +// ptx::elect_sync if not warp collective. +// TODO If TIDx is known at compile-time, generate custom mask. +Val* createElectSyncPredicateAsync() { + Val* zero = IrBuilder::create<Val>(0L, PrimDataType::UInt64); + Val* warp_size = IrBuilder::create<Val>(32L, PrimDataType::UInt64); + + const ParallelDimensionMap& pdim_map = + GpuLower::current()->parallelDimensionMap(); + Val* async_warp_thread_index = pdim_map.getLinearThreadIndexAsync(); + Val* warp_id = + SimplifyingIrBuilder::divExpr(async_warp_thread_index, warp_size); + // TODO Only select first warp now + Val* select_warp = SimplifyingIrBuilder::eqExpr(warp_id, zero); + + // Use elect-sync if available + if (pdim_map.canUseElectSyncInAsyncWarp()) { + return SimplifyingIrBuilder::logicalAndExpr( + select_warp, createElectSyncExpr()); + } + + // Warp Specialized ParallelType is ThreadIdx.x and it contains less than 32 + // threads, so manually select first thread in warp.
+ Val* thread_id = + SimplifyingIrBuilder::modExpr(async_warp_thread_index, warp_size); + Val* select_thread = SimplifyingIrBuilder::eqExpr(thread_id, zero); + return SimplifyingIrBuilder::logicalAndExpr(select_warp, select_thread); +} + +Val* createElectSyncPredicate(kir::Predicate* pred, bool is_async_warp) { NVF_ERROR(pred != nullptr); NVF_ERROR(pred->expr() != nullptr); @@ -558,7 +583,12 @@ Val* createElectSyncPredicate(kir::Predicate* pred) { NamedScalar::getParallelIndex(ParallelType::TIDx), zero); } } - return createElectSyncPredicate(/*use_first_warp=*/true, is_tma_store); + + NVF_ERROR(!(is_tma_store && is_async_warp)); + if (is_async_warp) { + return createElectSyncPredicateAsync(); + } + return selectFirstWarpElectSyncPredicate(is_tma_store); } Val* createSingleExpressionElectSync( @@ -579,6 +609,10 @@ Val* createSingleExpressionElectSync( auto pred_map = ParallelizedDomainPredicate::getPredicateMap(pred->expr(), loops); + bool is_async_warp = std::any_of(loops.begin(), loops.end(), [](ForLoop* fl) { + return fl->circularBufferLoopStage() == CircularBufferLoopStage::AsyncWarp; + }); + Val* parallel_dom_pred = GpuLower::current()->kernel()->trueVal(); for (auto pt : {ParallelType::TIDx, ParallelType::TIDy, ParallelType::TIDz}) { // short-circuit: parallelDim is not used by CTA @@ -607,7 +641,7 @@ Val* createSingleExpressionElectSync( if (pt == ParallelType::TIDx) { // Use createElectSyncPredicate for ParallelDim::TIDx. parallel_dom_pred = SimplifyingIrBuilder::logicalAndExpr( - parallel_dom_pred, createElectSyncPredicate(pred)); + parallel_dom_pred, createElectSyncPredicate(pred, is_async_warp)); } else { // Select first element of dimension for ParallelDim::TIDy and // ParallelDim::TIDz. 
@@ -648,7 +682,6 @@ Val* createMultipleExpressionElectSync( return fl->circularBufferLoopStage() == CircularBufferLoopStage::AsyncWarp; }); - bool is_register_sharing = false; if (async_warp_loop_it != loops.end()) { auto circular_buffer_type = std::get<WarpSpecialized>( GpuLower::current() ->circularBufferInfo() .getCircularBufferOptionsFor((*async_warp_loop_it)->iter_domain()) .type); async_warp_on = circular_buffer_type.on; - is_register_sharing = circular_buffer_type.num_registers.has_value(); } - // Short-circuit: register sharing is not used, don't need to pad a full warp - // group. If we are in a async warp, then the warp-dispatching IfThenElse - // already selects on `async_warp_on`, so we should not generate - // predicates for it here. - if (!is_register_sharing) { + // Short-circuit: If we are in an async warp, then the warp-dispatching + // IfThenElse already selects on `async_warp_on`, so we should not + // generate predicates for it here. + if (async_warp_loop_it == loops.end()) { Val* conditional = async_warp_on == ParallelType::TIDx ? pred->fusion()->trueVal() - : createElectSyncPredicate(); + : selectFirstWarpElectSyncPredicate(/*is_warp_collective=*/false); for (ParallelType pt : {ParallelType::TIDy, ParallelType::TIDz}) { if (pdim_map.has(pt) && async_warp_on != pt) { conditional = SimplifyingIrBuilder::logicalAndExpr( @@ -677,31 +708,7 @@ Val* createMultipleExpressionElectSync( return conditional; } - // If not specialized on TIDx, load branch has full size of bdimx, - // we can use the first warp, otherwise should use the last warp. - bool use_first_warp = async_warp_on != ParallelType::TIDx; - Val* conditional = createElectSyncPredicate(use_first_warp); - for (ParallelType pt : {ParallelType::TIDy, ParallelType::TIDz}) { - if (!pdim_map.has(pt)) { - continue; - } - if (async_warp_on != pt) { - // Not specialized on pt, use the first thread.
- conditional = SimplifyingIrBuilder::logicalAndExpr( - conditional, - IrBuilder::eqExpr(NamedScalar::getParallelIndex(pt), zero)); - } else { - // Specialized on pt, use the last thread. - Val* raw = GpuLower::current()->parallelDimensionMap().get(async_warp_on); - conditional = SimplifyingIrBuilder::logicalAndExpr( - conditional, - IrBuilder::eqExpr( - NamedScalar::getParallelIndex(pt), - IrBuilder::subExpr( - raw, IrBuilder::create<Val>(1, DataType::Index)))); - } - } - return conditional; + return createElectSyncPredicateAsync(); } } // namespace