diff --git a/csrc/parallel_dimension_map.cpp b/csrc/parallel_dimension_map.cpp index 462479b69fe..69fd763033a 100644 --- a/csrc/parallel_dimension_map.cpp +++ b/csrc/parallel_dimension_map.cpp @@ -47,10 +47,11 @@ void ParallelDimensionMap::build(Fusion* fusion) { tv->circularBufferOptions().type)) { const auto& warp_specialized = std::get<WarpSpecialized>(tv->circularBufferOptions().type); - warp_specialized_types_.insert(warp_specialized.on); - if (warp_specialized.num_registers.has_value()) { - ws_with_register_sharing_pt_ = warp_specialized.on; - } + NVF_ERROR( + !warp_specialized_parallel_type_.has_value() || + warp_specialized_parallel_type_.value() == warp_specialized.on, + "Multiple warp specialized axis detected."); + warp_specialized_parallel_type_ = warp_specialized.on; } for (auto id : tv->domain()->allIDs()) { auto ptype = id->getParallelType(); @@ -165,31 +166,13 @@ int64_t ParallelDimensionMap::getThreadCountInDim(ParallelType pt) { void ParallelDimensionMap::adjustMappingsForWarpSpecialization() { // shortcut for case without register sharing - if (!ws_with_register_sharing_pt_.has_value()) { - for (auto pt : warp_specialized_types_) { - auto dim_it = dim_map_.find(pt); - if (dim_it == dim_map_.end()) { - dim_map_[pt] = IrBuilder::create<Val>(2, DataType::Index); - } else { - // Intentionally not using SimplifyingIrBuilder::addExpr here so that - // we still have access to the pointer to the original IR node. - // We need the pointer to the original IR node because we want - // getRawCompute to be callable in an environment without FusionGuard, - // that is, when the IR container is read-only. In such an environment, - // we can't create new IR nodes for (x - 1). By using - // IrBuilder::addExpr, we can always create IR nodes like addExpr(x, 1), - // and SimplifyingIrBuilder::addExpr in getRawCompute will be able to - // simplify find the x when we do addExpr(addExpr(x, 1) - 1). 
- dim_map_[pt] = IrBuilder::addExpr( - dim_it->second, dim_it->second->fusion()->oneVal()); - } - exact_types_.erase(pt); - } + if (!warp_specialized_parallel_type_.has_value()) { return; } + // Warp specialization with register sharing on parallel type pt // index = TIDx + TIDy * bdimx + TIDz * bdimx * bdimy - auto pt = ws_with_register_sharing_pt_.value(); + auto pt = warp_specialized_parallel_type_.value(); auto dim_it = dim_map_.find(pt); int64_t pad_n_threads = 0; int64_t after_pad = 0; @@ -241,7 +224,7 @@ void ParallelDimensionMap::adjustMappingsForWarpSpecialization() { } // Apply the pad - ws_with_register_sharing_pad_val_ = pad_n_threads; + warp_specialized_padding_value_ = pad_n_threads; auto off_set = IrBuilder::create<Val>(pad_n_threads, DataType::Index); auto current_val = dim_it == dim_map_.end() ? IrBuilder::create<Val>(1, DataType::Index) @@ -274,7 +257,7 @@ bool ParallelDimensionMap::isExact(ParallelType pt) const { Val* ParallelDimensionMap::getRawCompute(ParallelType pt) const { Val* raw = getRaw(pt); - if (warp_specialized_types_.count(pt)) { + if (isWarpSpecialized(pt)) { int64_t padded_val = getWarpSpecializationPaddedVal(pt); return SimplifyingIrBuilder::addExpr(raw, -padded_val); } @@ -282,7 +265,7 @@ Val* ParallelDimensionMap::getRawCompute(ParallelType pt) const { } Val* ParallelDimensionMap::getRawAsync(ParallelType pt) const { - if (warp_specialized_types_.count(pt)) { + if (isWarpSpecialized(pt)) { return IrBuilder::create<Val>( getWarpSpecializationPaddedVal(pt), DataType::Index); } @@ -324,7 +307,7 @@ Val* ParallelDimensionMap::getLinearThreadIndexAsync() const { Val* pt_index = NamedScalar::getParallelIndex(pt); // Map the padded parallel index to [0, padded_value] range, so the linear // index will be in range of [0, 128). 
- if (warp_specialized_types_.count(pt)) { + if (isWarpSpecialized(pt)) { pt_index = SimplifyingIrBuilder::subExpr(pt_index, getRawCompute(pt)); } index = SimplifyingIrBuilder::addExpr( @@ -336,26 +319,25 @@ Val* ParallelDimensionMap::getLinearThreadIndexAsync() const { int64_t ParallelDimensionMap::getWarpSpecializationPaddedVal( ParallelType pt) const { - NVF_ERROR( - warp_specialized_types_.contains(pt), "Can't find ParallelType: ", pt); - if (!ws_with_register_sharing_pt_.has_value()) { + NVF_ERROR(isWarpSpecialized(pt), "Can't find ParallelType: ", pt); + if (!warp_specialized_parallel_type_.has_value()) { return 1; } NVF_ERROR( - ws_with_register_sharing_pt_.value() == pt, + warp_specialized_parallel_type_.value() == pt, "Can't find padded val for: ", pt); - return ws_with_register_sharing_pad_val_.value(); + return warp_specialized_padding_value_.value(); } bool ParallelDimensionMap::canUseElectSyncInAsyncWarp() const { // short-circuit: skip if warp specialization is not enabled - if (warp_specialized_types_.empty()) { + if (!hasWarpSpecialization()) { return true; } // Currently only support one warp specialized axis - NVF_ERROR(warp_specialized_types_.size() == 1); - ParallelType ws_pt = *warp_specialized_types_.begin(); + NVF_ERROR(warp_specialized_parallel_type_.has_value()); + ParallelType ws_pt = warp_specialized_parallel_type_.value(); // Check that BlockDim.x >= 32 active threads in AsyncWarp if (ws_pt != ParallelType::TIDx) { diff --git a/csrc/parallel_dimension_map.h b/csrc/parallel_dimension_map.h index f3f40faff1c..778fe21636c 100644 --- a/csrc/parallel_dimension_map.h +++ b/csrc/parallel_dimension_map.h @@ -78,7 +78,12 @@ class ParallelDimensionMap { //! Get if the kernel uses warp specialization bool hasWarpSpecialization() const { - return !warp_specialized_types_.empty(); + return warp_specialized_parallel_type_.has_value(); + } + + //! Check if ParallelType is WarpSpecialized parallel type. 
+ bool isWarpSpecialized(ParallelType pt) const { + return warp_specialized_parallel_type_.value_or(ParallelType::Serial) == pt; } bool has(ParallelType pt) const { @@ -111,12 +116,8 @@ class ParallelDimensionMap { //! exactly the same as extents of mapped domains. std::unordered_set<ParallelType> exact_types_; - //! Set of parallel types that we are doing warp specialization on - std::unordered_set<ParallelType> warp_specialized_types_; - - //! warp specialization with register sharing, keep track of - //! the parallel type and the threads padding - std::optional<ParallelType> ws_with_register_sharing_pt_; - std::optional<int64_t> ws_with_register_sharing_pad_val_; + //! Keep track of warp specialized parallel type and padding value + std::optional<ParallelType> warp_specialized_parallel_type_; + std::optional<int64_t> warp_specialized_padding_value_; }; } // namespace nvfuser diff --git a/tests/cpp/test_circular_buffering.cpp b/tests/cpp/test_circular_buffering.cpp index 05ecff64c12..c068f8ca4af 100644 --- a/tests/cpp/test_circular_buffering.cpp +++ b/tests/cpp/test_circular_buffering.cpp @@ -1096,12 +1096,22 @@ class TmaCircularBufferingTest NVFuserTest::SetUp(); } + bool testEnablesWarpSpecialization() { + return std::holds_alternative<WarpSpecialized>(circular_buffer_type); + } + bool testEnablesRegisterSharing() { return std::holds_alternative<WarpSpecialized>(circular_buffer_type) && std::get<WarpSpecialized>(circular_buffer_type) .num_registers.has_value(); } + bool testEnablesTIDx() { + return testEnablesWarpSpecialization() && + std::get<WarpSpecialized>(circular_buffer_type).on == + ParallelType::TIDx; + } + bool testEnablesRegisterSharingTIDx() { return testEnablesRegisterSharing() && std::get<WarpSpecialized>(circular_buffer_type).on == @@ -1776,8 +1786,8 @@ TEST_P(TmaCircularBufferingTest, OuterReduction) { TEST_P(TmaCircularBufferingTest, Persistent) { NVFUSER_TEST_CUDA_ARCH_GUARD(9, 0); - if (testEnablesRegisterSharing()) { - GTEST_SKIP() << "Bdimx is dynamic, register Sharing is disabled"; + if (testEnablesWarpSpecialization()) { + GTEST_SKIP() << "Bdimx is dynamic, Warp Specialization is disabled."; 
return; } @@ -1917,9 +1927,9 @@ TEST_P(TmaCircularBufferingTest, Persistent) { TEST_P(TmaCircularBufferingTest, Matmul) { NVFUSER_TEST_CUDA_ARCH_GUARD(9, 0); - if (testEnablesRegisterSharingTIDx()) { + if (testEnablesTIDx()) { GTEST_SKIP() - << "Register Sharing with TIDx used for both computation and load, requires TIDx to be a multiple of 128."; + << "Warp Specialization with TIDx used for both computation and load, requires TIDx to be a multiple of 128."; return; } @@ -2060,9 +2070,9 @@ TEST_P(TmaCircularBufferingTest, Matmul) { TEST_P(TmaCircularBufferingTest, MatmulWithBroadcastedInput) { NVFUSER_TEST_CUDA_ARCH_GUARD(9, 0); - if (testEnablesRegisterSharingTIDx()) { + if (testEnablesTIDx()) { GTEST_SKIP() - << "Register Sharing with TIDx used for both computation and load, requires TIDx to be a multiple of 128."; + << "Warp Specialization with TIDx used for both computation and load, requires TIDx to be a multiple of 128."; return; } diff --git a/tests/cpp/test_combined_inner_outer_reduction.cpp b/tests/cpp/test_combined_inner_outer_reduction.cpp index 984e794973d..4f3edc05e28 100644 --- a/tests/cpp/test_combined_inner_outer_reduction.cpp +++ b/tests/cpp/test_combined_inner_outer_reduction.cpp @@ -1071,10 +1071,13 @@ class TmaWarpSpecializedTest TEST_P(TmaWarpSpecializedTest, SimpleFusion) { NVFUSER_TEST_CUDA_ARCH_GUARD(9, 0); - auto [contig, _, dtype, dim0, dim1] = GetParam(); - if (!contig) { - GTEST_SKIP() << "TMA load requires contig inner domain."; + auto [contig, ws_enabled, dtype, dim0, dim1] = GetParam(); + + if (ws_enabled) { + GTEST_SKIP() << "Bdimx is dynamic, Warp Specialization is disabled."; + return; } + auto fusion = std::make_unique<Fusion>(); FusionGuard fg(fusion.get()); auto tv0 = makeContigTensor(2, dtype); @@ -1109,7 +1112,13 @@ TEST_P(TmaWarpSpecializedTest, SimpleFusion) { TEST_P(TmaWarpSpecializedTest, RMSNormBwd) { NVFUSER_TEST_CUDA_ARCH_GUARD(9, 0); - auto [contig, _, dtype, dim0, 
dim1] = GetParam(); + + if (ws_enabled) { + GTEST_SKIP() << "Bdimx is dynamic, Warp Specialization is disabled."; + return; + } + std::vector<int64_t> norm_shape{dim1}; auto fusion = std::make_unique<Fusion>();