Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion csrc/device_lower/pass/insert_syncs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -498,7 +498,15 @@ class ReadAfterWriteSyncs : public kir::ExprMutator {
// Add a fence before TMA store so that writes in the generic proxy is
// visible to the async proxy.
auto scope = scope_.empty() ? nullptr : scope_.back();
auto fence_async = IrBuilder::create<kir::FenceAsyncProxy>();
Expr* fence_async = IrBuilder::create<kir::FenceAsyncProxy>();
// Predicate the fence to select the first warp. TMA store is warp
// collective so ElectSync is not needed.
Val* warp_size = IrBuilder::create<Val>(32L, PrimDataType::UInt64);
Val* select_first_warp = IrBuilder::ltExpr(
NamedScalar::getParallelIndex(ParallelType::TIDx), warp_size);
auto* select_warp_pred =
IrBuilder::create<kir::Predicate>(select_first_warp);
fence_async = fence_async->withPredicate(select_warp_pred);
registerInsertBefore(expr, fence_async, scope);
}

Expand Down
12 changes: 7 additions & 5 deletions csrc/device_lower/pass/unroll.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,13 @@ void UnrollPass::dispatch(Expr* expr) {
return;
}

// short-circuit: mbarrier_init or mbarrier_inval with elect sync predicate.
// predicate is specified for tma load with circular buffering.
bool is_mbarrier_init = expr->isA<kir::MBarrierInit>();
bool is_mbarrier_inval = expr->isA<kir::MBarrierInvalidate>();
if ((is_mbarrier_init || is_mbarrier_inval) && expr->predicate() != nullptr) {
// predicate is specified for tma load with circular buffering. Also
// FenceAsyncProxy is predicated for TMA store
if (expr->predicate() != nullptr &&
expr->isOneOf<
kir::MBarrierInit,
kir::MBarrierInvalidate,
kir::FenceAsyncProxy>()) {
kir::IfThenElse* inline_ite =
IrBuilder::create<kir::IfThenElse>(expr->predicate());
kir::ExprMutator::registerReplace(expr, inline_ite);
Expand Down