Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 37 additions & 19 deletions csrc/predicate_compute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -490,14 +490,12 @@ namespace {

// Select first warp of threads along TIDx axis and then use ptx::elect_sync
// TODO If TIDx is known at compile-time, generate custom mask.
Val* createElectSyncPredicate(bool use_first_warp = true) {
Val* warp_size = IrBuilder::create<Val>(32L, PrimDataType::UInt64);
Val* full_mask_val = IrBuilder::create<Val>(0xFFFFFFFF, PrimDataType::UInt32);
Val* elect_sync_val = IrBuilder::create<Val>(PrimDataType::Bool);
IrBuilder::create<UnaryOp>(
UnaryOpType::ElectSync, elect_sync_val, full_mask_val);
Val* createElectSyncPredicate(
bool use_first_warp = true,
bool is_warp_collective = false) {
// If TIDx is used for both computation and TMA load, we should select a
// thread from the last warp along TIDx.
Val* warp_size = IrBuilder::create<Val>(32L, PrimDataType::UInt64);
auto select_warp = use_first_warp
? IrBuilder::ltExpr(
NamedScalar::getParallelIndex(ParallelType::TIDx), warp_size)
Expand All @@ -506,6 +504,18 @@ Val* createElectSyncPredicate(bool use_first_warp = true) {
IrBuilder::addExpr(
NamedScalar::getParallelDim(ParallelType::TIDx),
IrBuilder::create<Val>(-32L, PrimDataType::Index)));

// Short-Circuit: TMA Store is a warp-collective, so ElectSync is not
// necessary.
if (is_warp_collective) {
return select_warp;
}

// Create ElectSync Predicate to pick any thread in the warp
Val* full_mask_val = IrBuilder::create<Val>(0xFFFFFFFF, PrimDataType::UInt32);
Val* elect_sync_val = IrBuilder::create<Val>(PrimDataType::Bool);
IrBuilder::create<UnaryOp>(
UnaryOpType::ElectSync, elect_sync_val, full_mask_val);
return SimplifyingIrBuilder::logicalAndExpr(elect_sync_val, select_warp);
}

Expand Down Expand Up @@ -537,14 +547,18 @@ Val* createElectSyncPredicate(kir::Predicate* pred) {
}

// short-circuit: Expect ParallelType::TIDx to have at least one warp.
bool is_tma_store = ir_utils::isCpAsyncBulkStore(pred->expr());
if (tidx_paralleltype_dim->isConstScalar() &&
tidx_paralleltype_dim->evaluate().as<int64_t>() < 32) {
Val* zero = IrBuilder::create<Val>(0L, PrimDataType::UInt64);
return IrBuilder::eqExpr(
NamedScalar::getParallelIndex(ParallelType::TIDx), zero);
if (is_tma_store) {
return pred->fusion()->trueVal();
} else {
Val* zero = IrBuilder::create<Val>(0L, PrimDataType::UInt64);
return IrBuilder::eqExpr(
NamedScalar::getParallelIndex(ParallelType::TIDx), zero);
}
}

return createElectSyncPredicate();
return createElectSyncPredicate(/*use_first_warp=*/true, is_tma_store);
}

Val* createSingleExpressionElectSync(
Expand All @@ -559,9 +573,12 @@ Val* createSingleExpressionElectSync(

TensorView* out_tv = ir_utils::getTvOutput(pred->expr());
Val* zero = IrBuilder::create<Val>(0L, PrimDataType::UInt64);
const auto& pdim_map = GpuLower::current()->parallelDimensionMap();

const ParallelDimensionMap& pdim_map =
GpuLower::current()->parallelDimensionMap();
auto pred_map =
ParallelizedDomainPredicate::getPredicateMap(pred->expr(), loops);

Val* parallel_dom_pred = GpuLower::current()->kernel()->trueVal();
for (auto pt : {ParallelType::TIDx, ParallelType::TIDy, ParallelType::TIDz}) {
// short-circuit: parallelDim is not used by CTA
Expand All @@ -574,10 +591,10 @@ Val* createSingleExpressionElectSync(
// exists.
auto pred_info_it = pred_map.find(pt);
if (pred_info_it != pred_map.end()) {
const auto& pred_info = pred_info_it->second;
auto tid_pred = pred_info.getPredicate();
parallel_dom_pred =
SimplifyingIrBuilder::logicalAndExpr(parallel_dom_pred, tid_pred);
const ParallelizedDomainPredicate::PredicateInfo& pred_info =
pred_info_it->second;
parallel_dom_pred = SimplifyingIrBuilder::logicalAndExpr(
parallel_dom_pred, pred_info.getPredicate());
}

// Case 2: ParallelDim is used by CTA but not the TMA/Blackwell MMA
Expand Down Expand Up @@ -621,7 +638,8 @@ Val* createMultipleExpressionElectSync(
NVF_ERROR(pred->expr() == nullptr);

Val* zero = IrBuilder::create<Val>(0L, PrimDataType::UInt64);
const auto& pdim_map = GpuLower::current()->parallelDimensionMap();
const ParallelDimensionMap& pdim_map =
GpuLower::current()->parallelDimensionMap();

// Determine if warp specialized tma load expression.
ParallelType async_warp_on = ParallelType::Serial;
Expand Down Expand Up @@ -649,7 +667,7 @@ Val* createMultipleExpressionElectSync(
Val* conditional = async_warp_on == ParallelType::TIDx
? pred->fusion()->trueVal()
: createElectSyncPredicate();
for (auto pt : {ParallelType::TIDy, ParallelType::TIDz}) {
for (ParallelType pt : {ParallelType::TIDy, ParallelType::TIDz}) {
if (pdim_map.has(pt) && async_warp_on != pt) {
conditional = SimplifyingIrBuilder::logicalAndExpr(
conditional,
Expand All @@ -663,7 +681,7 @@ Val* createMultipleExpressionElectSync(
// we can use the first warp, otherwise should use the last warp.
bool use_first_warp = async_warp_on != ParallelType::TIDx;
Val* conditional = createElectSyncPredicate(use_first_warp);
for (auto pt : {ParallelType::TIDy, ParallelType::TIDz}) {
for (ParallelType pt : {ParallelType::TIDy, ParallelType::TIDz}) {
if (!pdim_map.has(pt)) {
continue;
}
Expand Down
54 changes: 47 additions & 7 deletions tests/cpp/test_memory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -235,8 +235,14 @@ class XorFinder : private kir::IrVisitor {
class TMAPredicateChecker : private kir::IrVisitor {
int64_t num_threads_;
int64_t cta_threads_;
TMAPredicateChecker(int64_t num_threads, int64_t cta_threads)
: num_threads_(num_threads), cta_threads_(cta_threads) {}
bool is_tma_store_;
TMAPredicateChecker(
int64_t num_threads,
int64_t cta_threads,
bool is_tma_store)
: num_threads_(num_threads),
cta_threads_(cta_threads),
is_tma_store_(is_tma_store) {}

kir::Predicate* pred_ = nullptr;

Expand Down Expand Up @@ -269,8 +275,31 @@ class TMAPredicateChecker : private kir::IrVisitor {
ASSERT_NE(pred_, nullptr);
auto cond = pred_->value();
ASSERT_NE(cond, nullptr);

// Handle TMA Store first
if (is_tma_store_) {
if (cta_threads_ <= 32) {
EXPECT_TRUE(cond->isTrue());
} else {
auto def = dynamic_cast<BinaryOp*>(cond->definition());
ASSERT_TRUE(def != nullptr);
EXPECT_TRUE(def->getBinaryOpType() == BinaryOpType::LT);
auto lhs = dynamic_cast<NamedScalar*>(def->lhs());
auto rhs = def->rhs();
ASSERT_TRUE(lhs != nullptr);
ASSERT_TRUE(rhs != nullptr);
EXPECT_TRUE(lhs->isThreadIdx());
EXPECT_TRUE(rhs->isConstInt());
EXPECT_EQ(rhs->value(), 32);
}
return;
}

// Then, handle TMA Load
if (num_threads_ == 0) {
EXPECT_TRUE(cond->isTrue());
} else if (is_tma_store_ && cta_threads_ <= 32) {
EXPECT_TRUE(cond->isTrue());
} else if (num_threads_ == 1 && cta_threads_ > 32) {
auto def = dynamic_cast<BinaryOp*>(cond->definition());
ASSERT_TRUE(def != nullptr);
Expand Down Expand Up @@ -324,8 +353,9 @@ class TMAPredicateChecker : private kir::IrVisitor {
static void checkPredicate(
kir::Kernel* kernel,
int64_t num_threads,
int64_t cta_threads = -1) {
TMAPredicateChecker checker(num_threads, cta_threads);
int64_t cta_threads = -1,
bool is_tma_store = false) {
TMAPredicateChecker checker(num_threads, cta_threads, is_tma_store);
checker.handle(kernel->topLevelExprs());
}
};
Expand Down Expand Up @@ -611,7 +641,10 @@ TEST_P(TMASimpleLdstTest, Store) {

EXPECT_EQ(TMADimChecker::getDim(ke.compiledKernel()->kernel()), dim);
TMAPredicateChecker::checkPredicate(
ke.compiledKernel()->kernel(), 1, ke.lastLaunchParams().nThreads());
ke.compiledKernel()->kernel(),
1,
ke.lastLaunchParams().nThreads(),
/*is_tma_store=*/true);
ASSERT_EQ(
XorFinder::findXor(ke.compiledKernel()->kernel()),
(swizzle != MmaInputSmemSwizzle::None));
Expand Down Expand Up @@ -2482,7 +2515,11 @@ TEST_F(TMADocTest, Figure14d) {
ke.compile(&fusion, {t0}, {}, matmul_cparams);

EXPECT_EQ(TMADimChecker::getDim(ke.compiledKernel()->kernel()), 2);
TMAPredicateChecker::checkPredicate(ke.compiledKernel()->kernel(), 1);
TMAPredicateChecker::checkPredicate(
ke.compiledKernel()->kernel(),
1,
ke.lastLaunchParams().nThreads(),
/*is_tma_store=*/true);

auto cg_outputs = ke.run({t0});
testValidate(&fusion, cg_outputs, {t0}, {t0}, __LINE__, __FILE__);
Expand Down Expand Up @@ -2565,7 +2602,10 @@ TEST_F(TMADocTest, Figure14e) {

EXPECT_EQ(TMADimChecker::getDim(ke.compiledKernel()->kernel()), 2);
TMAPredicateChecker::checkPredicate(
ke.compiledKernel()->kernel(), 1, ke.lastLaunchParams().nThreads());
ke.compiledKernel()->kernel(),
1,
ke.lastLaunchParams().nThreads(),
/*is_tma_store=*/true);
}

TEST_F(TMADocTest, Figure15a) {
Expand Down