Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 40 additions & 4 deletions lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2009,10 +2009,25 @@ class TritonGPURemoveLayoutConversionsPass
}
continue;
}
// TODO: propagate through scf.yield by updating parent op result
// types, scf.for iter_args, and init values to match srcEnc.
if (isa<scf::YieldOp>(user))
// scf.yield passes values through to the parent op's results.
// For ForOp/WhileOp, the parent results are tied to block arguments
// and init operands via loop-carried dependencies — in-place type
// rewriting cannot safely update all of them, so block propagation.
// For IfOp, the results are simple branches with no loop-carried
// deps, so propagation is safe if we also follow the IfOp results.
if (auto yieldOp = dyn_cast<scf::YieldOp>(user)) {
Operation *parent = yieldOp->getParentOp();
if (isa<scf::ForOp, scf::WhileOp>(parent))
return false;
if (auto ifOp = dyn_cast<scf::IfOp>(parent)) {
for (Value result : ifOp.getResults()) {
if (isa<RankedTensorType>(result.getType()))
worklist.push_back(result);
}
continue;
}
return false;
}
// Any other user (dot, reduce, another convert, etc.) blocks
// propagation.
return false;
Expand All @@ -2034,6 +2049,7 @@ class TritonGPURemoveLayoutConversionsPass

// Collect all ops that need type rewriting (forward from convert users).
SmallVector<Operation *> opsToRewrite;
SetVector<Operation *> ifOpsToRewrite;
SmallVector<Value> worklist = {dst};
DenseSet<Value> visited;

Expand All @@ -2043,8 +2059,20 @@ class TritonGPURemoveLayoutConversionsPass
continue;
for (OpOperand &use : v.getUses()) {
Operation *user = use.getOwner();
if (isa<LocalStoreOp>(user) || isa<scf::YieldOp>(user))
if (isa<LocalStoreOp>(user))
continue;
// For scf.yield under scf.if, follow through to the IfOp results.
// ForOp/WhileOp yields are blocked by canPropagateSrcEncodingThroughUsers.
if (auto yieldOp = dyn_cast<scf::YieldOp>(user)) {
if (auto ifOp = dyn_cast<scf::IfOp>(yieldOp->getParentOp())) {
ifOpsToRewrite.insert(ifOp.getOperation());
for (Value result : ifOp.getResults()) {
if (isa<RankedTensorType>(result.getType()))
worklist.push_back(result);
}
}
continue;
}
opsToRewrite.push_back(user);
for (Value result : user->getResults()) {
if (isa<RankedTensorType>(result.getType()))
Expand Down Expand Up @@ -2116,6 +2144,14 @@ class TritonGPURemoveLayoutConversionsPass
}
}
}
// Rewrite IfOp result types that we propagated through.
for (Operation *op : ifOpsToRewrite) {
for (Value result : op->getResults()) {
if (auto ty = dyn_cast<RankedTensorType>(result.getType())) {
result.setType(ty.cloneWithEncoding(srcEnc));
}
}
}

// Replace all uses of the convert result with the convert source.
dst.replaceAllUsesWith(src);
Expand Down
6 changes: 3 additions & 3 deletions python/triton/language/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2093,9 +2093,9 @@ def dot(input, other, acc=None, input_precision=None, allow_tf32=None, max_num_i
:param input_precision: How to exercise the Tensor Cores for f32 x f32. If
the device does not have Tensor Cores or the inputs are not of dtype f32,
this option is ignored. For devices that do have tensor cores, the
default precision is tf32.
:type input_precision: string. Available options for nvidia: :code:`"tf32"`, :code:`"tf32x3"`, :code:`"ieee"`. Default: :code:`"tf32"`. Available options for amd: :code:`"ieee"`, (CDNA3 only) :code:`"tf32"`.
:param allow_tf32: *Deprecated.* If true, input_precision is set to "tf32".
default precision is tf32x3.
:type input_precision: string. Available options for nvidia: :code:`"tf32"`, :code:`"tf32x3"`, :code:`"ieee"`. Default: :code:`"tf32x3"`. Available options for amd: :code:`"ieee"`, (CDNA3 only) :code:`"tf32"`.
:param allow_tf32: *Deprecated.* If true, input_precision is set to "tf32x3".
Only one of :code:`input_precision` and :code:`allow_tf32` can be
specified (i.e. at least one must be :code:`None`).
:param attrs: Optional dictionary of string-valued attributes to attach to the dot operation.
Expand Down
4 changes: 2 additions & 2 deletions python/triton/language/semantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1629,8 +1629,8 @@ def dot_precheck(
assert input_precision is None or tl._unwrap_if_constexpr(allow_tf32) is None, (
"Only one of input_precision and allow_tf32 can be specified")
if input_precision is None:
supports_tf32 = "tf32" in self.builder.options.allowed_dot_input_precisions
input_precision = knobs.language.fp32_default or ("tf32" if
supports_tf32 = "tf32x3" in self.builder.options.allowed_dot_input_precisions
input_precision = knobs.language.fp32_default or ("tf32x3" if
(supports_tf32 and
(allow_tf32 or allow_tf32 is None)) else "ieee")

Expand Down
2 changes: 1 addition & 1 deletion python/triton/runtime/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ class InterpreterOptions:
arch: Optional[str] = None
supported_fp8_dtypes: Tuple[str, ...] = ("fp8e5", "fp8e5b16", "fp8e4nv", "fp8e4b8", "fp8e4b15")
deprecated_fp8_dot_operand_dtypes: Tuple[str, ...] = ()
default_dot_input_precision: str = "tf32"
default_dot_input_precision: str = "tf32x3"
allowed_dot_input_precisions: Tuple[str, ...] = ("tf32", "tf32x3", "ieee")
max_num_imprecise_acc_default: int = 0
backend_name: str = "interpreter"
Expand Down
4 changes: 3 additions & 1 deletion third_party/nvidia/backend/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,7 +474,9 @@ def make_ttgir(mod, metadata, opt, capability):
# Budget-aware layout conversion elimination — runs last to ensure
# converts whose scratch would exceed SMEM budget are eliminated
# after all other passes that may introduce layout conversions.
passes.ttgpuir.add_remove_layout_conversions(pm, smem_budget)
# TODO(njriasan): Re-enable once propagateSrcEncodingAndErase handles
# scf::ForOp/WhileOp loop-carried values correctly.
passes.ttgpuir.add_remove_layout_conversions(pm, 0)

pm.run(mod, 'make_ttgir')
metadata["tensordesc_meta"] = mod.get_tensordesc_metadata()
Expand Down
Loading