Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 19 additions & 37 deletions csrc/parallel_dimension_map.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,11 @@ void ParallelDimensionMap::build(Fusion* fusion) {
tv->circularBufferOptions().type)) {
const auto& warp_specialized =
std::get<WarpSpecialized>(tv->circularBufferOptions().type);
warp_specialized_types_.insert(warp_specialized.on);
if (warp_specialized.num_registers.has_value()) {
ws_with_register_sharing_pt_ = warp_specialized.on;
}
NVF_ERROR(
!warp_specialized_parallel_type_.has_value() ||
warp_specialized_parallel_type_.value() == warp_specialized.on,
"Multiple warp specialized axis detected.");
warp_specialized_parallel_type_ = warp_specialized.on;
}
for (auto id : tv->domain()->allIDs()) {
auto ptype = id->getParallelType();
Expand Down Expand Up @@ -165,31 +166,13 @@ int64_t ParallelDimensionMap::getThreadCountInDim(ParallelType pt) {

void ParallelDimensionMap::adjustMappingsForWarpSpecialization() {
// shortcut for case without warp specialization
if (!ws_with_register_sharing_pt_.has_value()) {
for (auto pt : warp_specialized_types_) {
auto dim_it = dim_map_.find(pt);
if (dim_it == dim_map_.end()) {
dim_map_[pt] = IrBuilder::create<Val>(2, DataType::Index);
} else {
// Intentionally not using SimplifyingIrBuilder::addExpr here so that
// we still have access to the pointer to the original IR node.
// We need the pointer to the original IR node because we want
// getRawCompute to be callable in an environment without FusionGuard,
// that is, when the IR container is read-only. In such an environment,
// we can't create new IR nodes for (x - 1). By using
// IrBuilder::addExpr, we can always create IR nodes like addExpr(x, 1),
// and SimplifyingIrBuilder::addExpr in getRawCompute will be able to
// simplify find the x when we do addExpr(addExpr(x, 1) - 1).
dim_map_[pt] = IrBuilder::addExpr(
dim_it->second, dim_it->second->fusion()->oneVal());
}
exact_types_.erase(pt);
}
if (!warp_specialized_parallel_type_.has_value()) {
return;
}

// Warp specialization on parallel type pt; pad the dimension so the
// index = TIDx + TIDy * bdimx + TIDz * bdimx * bdimy
auto pt = ws_with_register_sharing_pt_.value();
auto pt = warp_specialized_parallel_type_.value();
auto dim_it = dim_map_.find(pt);
int64_t pad_n_threads = 0;
int64_t after_pad = 0;
Expand Down Expand Up @@ -241,7 +224,7 @@ void ParallelDimensionMap::adjustMappingsForWarpSpecialization() {
}

// Apply the pad
ws_with_register_sharing_pad_val_ = pad_n_threads;
warp_specialized_padding_value_ = pad_n_threads;
auto off_set = IrBuilder::create<Val>(pad_n_threads, DataType::Index);
auto current_val = dim_it == dim_map_.end()
? IrBuilder::create<Val>(1, DataType::Index)
Expand Down Expand Up @@ -274,15 +257,15 @@ bool ParallelDimensionMap::isExact(ParallelType pt) const {

Val* ParallelDimensionMap::getRawCompute(ParallelType pt) const {
Val* raw = getRaw(pt);
if (warp_specialized_types_.count(pt)) {
if (isWarpSpecialized(pt)) {
int64_t padded_val = getWarpSpecializationPaddedVal(pt);
return SimplifyingIrBuilder::addExpr(raw, -padded_val);
}
return raw;
}

Val* ParallelDimensionMap::getRawAsync(ParallelType pt) const {
if (warp_specialized_types_.count(pt)) {
if (isWarpSpecialized(pt)) {
return IrBuilder::create<Val>(
getWarpSpecializationPaddedVal(pt), DataType::Index);
}
Expand Down Expand Up @@ -324,7 +307,7 @@ Val* ParallelDimensionMap::getLinearThreadIndexAsync() const {
Val* pt_index = NamedScalar::getParallelIndex(pt);
// Map the padded parallel index to [0, padded_value] range, so the linear
// index will be in range of [0, 128).
if (warp_specialized_types_.count(pt)) {
if (isWarpSpecialized(pt)) {
pt_index = SimplifyingIrBuilder::subExpr(pt_index, getRawCompute(pt));
}
index = SimplifyingIrBuilder::addExpr(
Expand All @@ -336,26 +319,25 @@ Val* ParallelDimensionMap::getLinearThreadIndexAsync() const {

int64_t ParallelDimensionMap::getWarpSpecializationPaddedVal(
ParallelType pt) const {
NVF_ERROR(
warp_specialized_types_.contains(pt), "Can't find ParallelType: ", pt);
if (!ws_with_register_sharing_pt_.has_value()) {
NVF_ERROR(isWarpSpecialized(pt), "Can't find ParallelType: ", pt);
if (!warp_specialized_parallel_type_.has_value()) {
return 1;
}
NVF_ERROR(
ws_with_register_sharing_pt_.value() == pt,
warp_specialized_parallel_type_.value() == pt,
"Can't find padded val for: ",
pt);
return ws_with_register_sharing_pad_val_.value();
return warp_specialized_padding_value_.value();
}

bool ParallelDimensionMap::canUseElectSyncInAsyncWarp() const {
// short-circuit: skip if warp specialization is not enabled
if (warp_specialized_types_.empty()) {
if (!hasWarpSpecialization()) {
return true;
}
// Currently only one warp specialized axis is supported
NVF_ERROR(warp_specialized_types_.size() == 1);
ParallelType ws_pt = *warp_specialized_types_.begin();
NVF_ERROR(warp_specialized_parallel_type_.has_value());
ParallelType ws_pt = warp_specialized_parallel_type_.value();

// Check that BlockDim.x >= 32 active threads in AsyncWarp
if (ws_pt != ParallelType::TIDx) {
Expand Down
17 changes: 9 additions & 8 deletions csrc/parallel_dimension_map.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,12 @@ class ParallelDimensionMap {

//! Get if the kernel uses warp specialization
bool hasWarpSpecialization() const {
return !warp_specialized_types_.empty();
return warp_specialized_parallel_type_.has_value();
}

//! Check if ParallelType is WarpSpecialized parallel type.
bool isWarpSpecialized(ParallelType pt) const {
return warp_specialized_parallel_type_.value_or(ParallelType::Serial) == pt;
}

bool has(ParallelType pt) const {
Expand Down Expand Up @@ -111,12 +116,8 @@ class ParallelDimensionMap {
//! exactly the same as extents of mapped domains.
std::unordered_set<ParallelType> exact_types_;

//! Set of parallel types that we are doing warp specialization on
std::unordered_set<ParallelType> warp_specialized_types_;

//! warp specialization with register sharing, keep track of
//! the parallel type and the threads padding
std::optional<ParallelType> ws_with_register_sharing_pt_;
std::optional<int64_t> ws_with_register_sharing_pad_val_;
//! Keep track of warp specialized parallel type and padding value
std::optional<ParallelType> warp_specialized_parallel_type_;
std::optional<int64_t> warp_specialized_padding_value_;
};
} // namespace nvfuser
22 changes: 16 additions & 6 deletions tests/cpp/test_circular_buffering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1096,12 +1096,22 @@ class TmaCircularBufferingTest
NVFuserTest::SetUp();
}

bool testEnablesWarpSpecialization() {
return std::holds_alternative<WarpSpecialized>(circular_buffer_type);
}

bool testEnablesRegisterSharing() {
return std::holds_alternative<WarpSpecialized>(circular_buffer_type) &&
std::get<WarpSpecialized>(circular_buffer_type)
.num_registers.has_value();
}

bool testEnablesTIDx() {
return testEnablesWarpSpecialization() &&
std::get<WarpSpecialized>(circular_buffer_type).on ==
ParallelType::TIDx;
}

bool testEnablesRegisterSharingTIDx() {
return testEnablesRegisterSharing() &&
std::get<WarpSpecialized>(circular_buffer_type).on ==
Expand Down Expand Up @@ -1776,8 +1786,8 @@ TEST_P(TmaCircularBufferingTest, OuterReduction) {

TEST_P(TmaCircularBufferingTest, Persistent) {
NVFUSER_TEST_CUDA_ARCH_GUARD(9, 0);
if (testEnablesRegisterSharing()) {
GTEST_SKIP() << "Bdimx is dynamic, register Sharing is disabled";
if (testEnablesWarpSpecialization()) {
GTEST_SKIP() << "Bdimx is dynamic, Warp Specialization is disabled.";
return;
}

Expand Down Expand Up @@ -1917,9 +1927,9 @@ TEST_P(TmaCircularBufferingTest, Persistent) {
TEST_P(TmaCircularBufferingTest, Matmul) {
NVFUSER_TEST_CUDA_ARCH_GUARD(9, 0);

if (testEnablesRegisterSharingTIDx()) {
if (testEnablesTIDx()) {
GTEST_SKIP()
<< "Register Sharing with TIDx used for both computation and load, requires TIDx to be a multiple of 128.";
<< "Warp Specialization with TIDx used for both computation and load, requires TIDx to be a multiple of 128.";
return;
}

Expand Down Expand Up @@ -2060,9 +2070,9 @@ TEST_P(TmaCircularBufferingTest, Matmul) {
TEST_P(TmaCircularBufferingTest, MatmulWithBroadcastedInput) {
NVFUSER_TEST_CUDA_ARCH_GUARD(9, 0);

if (testEnablesRegisterSharingTIDx()) {
if (testEnablesTIDx()) {
GTEST_SKIP()
<< "Register Sharing with TIDx used for both computation and load, requires TIDx to be a multiple of 128.";
<< "Warp Specialization with TIDx used for both computation and load, requires TIDx to be a multiple of 128.";
return;
}

Expand Down
17 changes: 13 additions & 4 deletions tests/cpp/test_combined_inner_outer_reduction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1071,10 +1071,13 @@ class TmaWarpSpecializedTest

TEST_P(TmaWarpSpecializedTest, SimpleFusion) {
NVFUSER_TEST_CUDA_ARCH_GUARD(9, 0);
auto [contig, _, dtype, dim0, dim1] = GetParam();
if (!contig) {
GTEST_SKIP() << "TMA load requires contig inner domain.";
auto [contig, ws_enabled, dtype, dim0, dim1] = GetParam();

if (ws_enabled) {
GTEST_SKIP() << "Bdimx is dynamic, Warp Specialization is disabled.";
return;
}

auto fusion = std::make_unique<Fusion>();
FusionGuard fg(fusion.get());
auto tv0 = makeContigTensor(2, dtype);
Expand Down Expand Up @@ -1109,7 +1112,13 @@ TEST_P(TmaWarpSpecializedTest, SimpleFusion) {

TEST_P(TmaWarpSpecializedTest, RMSNormBwd) {
NVFUSER_TEST_CUDA_ARCH_GUARD(9, 0);
auto [contig, _, dtype, dim0, dim1] = GetParam();
auto [contig, ws_enabled, dtype, dim0, dim1] = GetParam();

if (ws_enabled) {
GTEST_SKIP() << "Bdimx is dynamic, Warp Specialization is disabled.";
return;
}

std::vector<int64_t> norm_shape{dim1};

auto fusion = std::make_unique<Fusion>();
Expand Down