Skip to content
54 changes: 54 additions & 0 deletions csrc/parallel_dimension_map.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,39 @@ Val* ParallelDimensionMap::getNumComputeThreadsEachBlock() const {
return num_threads;
}

// For warp-specialization, the CTA is padded so the AsyncWarp contains 128
// threads. This function maps the AsyncWarp CTA to a linear index from
// [0, 128). It is used to divide AsyncWarp into four independent warps.
Val* ParallelDimensionMap::getLinearThreadIndexAsync() const {
  Val* linear_index = GpuLower::current()->kernel()->zeroVal();
  Val* stride = GpuLower::current()->kernel()->oneVal();

  for (auto pt : kParallelTypeTIDs) {
    // For warp-specialization, an axis is padded so the AsyncWarp contains
    // 128 threads.
    Val* dim_extent = getRawAsync(pt);
    // Skip axes that are unused in the kernel or have a trivial extent of 1;
    // neither contributes to the linear index.
    if (dim_extent == nullptr ||
        (dim_extent->isConstScalar() &&
         dim_extent->evaluate().as<int64_t>() == 1)) {
      continue;
    }
    Val* axis_index = NamedScalar::getParallelIndex(pt);
    // Shift the warp-specialized axis so the AsyncWarp portion of the padded
    // index falls in [0, padded_value), keeping the linear index in [0, 128).
    if (warp_specialized_types_.count(pt)) {
      axis_index = SimplifyingIrBuilder::subExpr(axis_index, getRawCompute(pt));
    }
    // Accumulate in (TDZ, TDY, TDX) row-major order: index += axis * stride.
    linear_index = SimplifyingIrBuilder::addExpr(
        linear_index, SimplifyingIrBuilder::mulExpr(axis_index, stride));
    stride = SimplifyingIrBuilder::mulExpr(stride, dim_extent);
  }
  return linear_index;
}

int64_t ParallelDimensionMap::getWarpSpecializationPaddedVal(
ParallelType pt) const {
NVF_ERROR(
Expand All @@ -315,6 +348,27 @@ int64_t ParallelDimensionMap::getWarpSpecializationPaddedVal(
return ws_with_register_sharing_pad_val_.value();
}

bool ParallelDimensionMap::canUseElectSyncInAsyncWarp() const {
// short-circuit: skip if warp specialization is not enabled
if (warp_specialized_types_.empty()) {
return true;
}
// Currently only support one warp specialized axis
NVF_ERROR(warp_specialized_types_.size() == 1);
ParallelType ws_pt = *warp_specialized_types_.begin();

// Check that BlockDim.x >= 32 active threads in AsyncWarp
if (ws_pt != ParallelType::TIDx) {
return true;
}

if (getWarpSpecializationPaddedVal(ws_pt) >= 32) {
return true;
}

return false;
}

std::string ParallelDimensionMap::toString() const {
std::stringstream ss;
for (auto pt : kParallelTypeThreads) {
Expand Down
7 changes: 7 additions & 0 deletions csrc/parallel_dimension_map.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,9 @@ class ParallelDimensionMap {
//! buffer tensors.
Val* getNumComputeThreadsEachBlock() const;

//! Assign linear index to each thread of CTA. Assume (TDZ, TDY, TDX) order.
Val* getLinearThreadIndexAsync() const;

//! Get if the kernel uses warp specialization
bool hasWarpSpecialization() const {
return !warp_specialized_types_.empty();
Expand All @@ -82,6 +85,10 @@ class ParallelDimensionMap {
return dim_map_.count(pt) > 0;
}

// If warp specialized on TIDx and padded value is less than 32 threads, then
// elect-sync cannot be used.
bool canUseElectSyncInAsyncWarp() const;

private:
//! Get number of threads for ParallelType axis
//! Not used: 1, Const: n, Dynamic: -1
Expand Down
121 changes: 64 additions & 57 deletions csrc/predicate_compute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -488,38 +488,63 @@ std::size_t UnswitchPredicateKeyHash::operator()(

namespace {

// Build a ptx::elect_sync expression over the full-warp mask and return the
// boolean Val that holds its result (true on exactly one thread of the warp).
Val* createElectSyncExpr() {
  Val* mask = IrBuilder::create<Val>(0xFFFFFFFF, PrimDataType::UInt32);
  Val* elected = IrBuilder::create<Val>(PrimDataType::Bool);
  IrBuilder::create<UnaryOp>(UnaryOpType::ElectSync, elected, mask);
  return elected;
}

// Select first warp of threads along TIDx axis and use ptx::elect_sync if not
// warp collective.
// TODO If TIDx is known at compile-time, generate custom mask.
Val* selectFirstWarpElectSyncPredicate(bool is_warp_collective) {
  Val* warp_size = IrBuilder::create<Val>(32L, PrimDataType::UInt64);
  // threadIdx.x < 32 selects the first warp along TIDx.
  Val* select_first_warp = IrBuilder::ltExpr(
      NamedScalar::getParallelIndex(ParallelType::TIDx), warp_size);

  // Short-Circuit: TMA Store is a warp-collective, so ElectSync is not
  // necessary.
  if (is_warp_collective) {
    return select_first_warp;
  }

  // Otherwise, additionally elect a single thread within the first warp.
  return SimplifyingIrBuilder::logicalAndExpr(
      createElectSyncExpr(), select_first_warp);
}

// Get linear index for AsyncWarp Group. Then, select first warp. Finally, use
// ptx::elect_sync if not warp collective.
// TODO If TIDx is known at compile-time, generate custom mask.
Val* createElectSyncPredicateAsync() {
  Val* zero = IrBuilder::create<Val>(0L, PrimDataType::UInt64);
  Val* warp_size = IrBuilder::create<Val>(32L, PrimDataType::UInt64);

  const ParallelDimensionMap& pdim_map =
      GpuLower::current()->parallelDimensionMap();
  // Linear index of this thread within the AsyncWarp group, in [0, 128).
  Val* async_warp_thread_index = pdim_map.getLinearThreadIndexAsync();
  Val* warp_id =
      SimplifyingIrBuilder::divExpr(async_warp_thread_index, warp_size);
  // TODO Only select first warp now
  Val* select_warp = SimplifyingIrBuilder::eqExpr(warp_id, zero);

  // Use elect-sync if available
  if (pdim_map.canUseElectSyncInAsyncWarp()) {
    return SimplifyingIrBuilder::logicalAndExpr(
        select_warp, createElectSyncExpr());
  }

  // Warp Specialized ParallelType is ThreadIdx.x and it contains less than 32
  // threads, so manually select first thread in warp.
  Val* thread_id =
      SimplifyingIrBuilder::modExpr(async_warp_thread_index, warp_size);
  Val* select_thread = SimplifyingIrBuilder::eqExpr(thread_id, zero);
  return SimplifyingIrBuilder::logicalAndExpr(select_warp, select_thread);
}

Val* createElectSyncPredicate(kir::Predicate* pred, bool is_async_warp) {
NVF_ERROR(pred != nullptr);
NVF_ERROR(pred->expr() != nullptr);

Expand Down Expand Up @@ -558,7 +583,12 @@ Val* createElectSyncPredicate(kir::Predicate* pred) {
NamedScalar::getParallelIndex(ParallelType::TIDx), zero);
}
}
return createElectSyncPredicate(/*use_first_warp=*/true, is_tma_store);

NVF_ERROR(!(is_tma_store && is_async_warp));
if (is_async_warp) {
return createElectSyncPredicateAsync();
}
return selectFirstWarpElectSyncPredicate(is_tma_store);
}

Val* createSingleExpressionElectSync(
Expand All @@ -579,6 +609,10 @@ Val* createSingleExpressionElectSync(
auto pred_map =
ParallelizedDomainPredicate::getPredicateMap(pred->expr(), loops);

bool is_async_warp = std::any_of(loops.begin(), loops.end(), [](ForLoop* fl) {
return fl->circularBufferLoopStage() == CircularBufferLoopStage::AsyncWarp;
});

Val* parallel_dom_pred = GpuLower::current()->kernel()->trueVal();
for (auto pt : {ParallelType::TIDx, ParallelType::TIDy, ParallelType::TIDz}) {
// short-circuit: parallelDim is not used by CTA
Expand Down Expand Up @@ -607,7 +641,7 @@ Val* createSingleExpressionElectSync(
if (pt == ParallelType::TIDx) {
// Use createElectSyncPredicate for ParallelDim::TIDx.
parallel_dom_pred = SimplifyingIrBuilder::logicalAndExpr(
parallel_dom_pred, createElectSyncPredicate(pred));
parallel_dom_pred, createElectSyncPredicate(pred, is_async_warp));
} else {
// Select first element of dimension for ParallelDim::TIDy and
// ParallelDim::TIDz.
Expand Down Expand Up @@ -648,25 +682,22 @@ Val* createMultipleExpressionElectSync(
return fl->circularBufferLoopStage() ==
CircularBufferLoopStage::AsyncWarp;
});
bool is_register_sharing = false;
if (async_warp_loop_it != loops.end()) {
auto circular_buffer_type = std::get<WarpSpecialized>(
GpuLower::current()
->circularBufferInfo()
.getCircularBufferOptionsFor((*async_warp_loop_it)->iter_domain())
.type);
async_warp_on = circular_buffer_type.on;
is_register_sharing = circular_buffer_type.num_registers.has_value();
}

// Short-circuit: register sharing is not used, don't need to pad a full warp
// group. If we are in a async warp, then the warp-dispatching IfThenElse
// already selects on `async_warp_on`, so we should not generate
// predicates for it here.
if (!is_register_sharing) {
// Short-circuit: If we are in a async warp, then the warp-dispatching
// IfThenElse already selects on `async_warp_on`, so we should not
// generate predicates for it here.
if (async_warp_loop_it == loops.end()) {
Val* conditional = async_warp_on == ParallelType::TIDx
? pred->fusion()->trueVal()
: createElectSyncPredicate();
: selectFirstWarpElectSyncPredicate(/*is_warp_collective=*/false);
for (ParallelType pt : {ParallelType::TIDy, ParallelType::TIDz}) {
if (pdim_map.has(pt) && async_warp_on != pt) {
conditional = SimplifyingIrBuilder::logicalAndExpr(
Expand All @@ -677,31 +708,7 @@ Val* createMultipleExpressionElectSync(
return conditional;
}

// If not specialized on TIDx, load branch has full size of bdimx,
// we can use the first warp, otherwise should use the last warp.
bool use_first_warp = async_warp_on != ParallelType::TIDx;
Val* conditional = createElectSyncPredicate(use_first_warp);
for (ParallelType pt : {ParallelType::TIDy, ParallelType::TIDz}) {
if (!pdim_map.has(pt)) {
continue;
}
if (async_warp_on != pt) {
// Not specialized on pt, use the first thread.
conditional = SimplifyingIrBuilder::logicalAndExpr(
conditional,
IrBuilder::eqExpr(NamedScalar::getParallelIndex(pt), zero));
} else {
// Specialized on pt, use the last thread.
Val* raw = GpuLower::current()->parallelDimensionMap().get(async_warp_on);
conditional = SimplifyingIrBuilder::logicalAndExpr(
conditional,
IrBuilder::eqExpr(
NamedScalar::getParallelIndex(pt),
IrBuilder::subExpr(
raw, IrBuilder::create<Val>(1, DataType::Index))));
}
}
return conditional;
return createElectSyncPredicateAsync();
}

} // namespace
Expand Down