diff --git a/csrc/device_lower/pass/insert_syncs.cpp b/csrc/device_lower/pass/insert_syncs.cpp index 3d1503c28f4..b1108287292 100644 --- a/csrc/device_lower/pass/insert_syncs.cpp +++ b/csrc/device_lower/pass/insert_syncs.cpp @@ -498,7 +498,15 @@ class ReadAfterWriteSyncs : public kir::ExprMutator { // Add a fence before TMA store so that writes in the generic proxy is // visible to the async proxy. auto scope = scope_.empty() ? nullptr : scope_.back(); - auto fence_async = IrBuilder::create(); + Expr* fence_async = IrBuilder::create(); + // Predicate the fence to select the first warp. TMA store is warp + // collective so ElectSync is not needed. + Val* warp_size = IrBuilder::create(32L, PrimDataType::UInt64); + Val* select_first_warp = IrBuilder::ltExpr( + NamedScalar::getParallelIndex(ParallelType::TIDx), warp_size); + auto* select_warp_pred = + IrBuilder::create(select_first_warp); + fence_async = fence_async->withPredicate(select_warp_pred); registerInsertBefore(expr, fence_async, scope); } diff --git a/csrc/device_lower/pass/unroll.cpp b/csrc/device_lower/pass/unroll.cpp index 42af607ae2f..aaec335f179 100644 --- a/csrc/device_lower/pass/unroll.cpp +++ b/csrc/device_lower/pass/unroll.cpp @@ -51,11 +51,13 @@ void UnrollPass::dispatch(Expr* expr) { return; } - // short-circuit: mbarrier_init or mbarrier_inval with elect sync predicate. - // predicate is specified for tma load with circular buffering. - bool is_mbarrier_init = expr->isA(); - bool is_mbarrier_inval = expr->isA(); - if ((is_mbarrier_init || is_mbarrier_inval) && expr->predicate() != nullptr) { + // predicate is specified for tma load with circular buffering. Also + // FenceAsyncProxy is predicated for TMA store + if (expr->predicate() != nullptr && + expr->isOneOf< + kir::MBarrierInit, + kir::MBarrierInvalidate, + kir::FenceAsyncProxy>()) { kir::IfThenElse* inline_ite = IrBuilder::create(expr->predicate()); kir::ExprMutator::registerReplace(expr, inline_ite);