diff --git a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp index 2c0b780952..2cb307395f 100644 --- a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp @@ -2009,10 +2009,25 @@ class TritonGPURemoveLayoutConversionsPass } continue; } - // TODO: propagate through scf.yield by updating parent op result - // types, scf.for iter_args, and init values to match srcEnc. - if (isa(user)) + // scf.yield passes values through to the parent op's results. + // For ForOp/WhileOp, the parent results are tied to block arguments + // and init operands via loop-carried dependencies — in-place type + // rewriting cannot safely update all of them, so block propagation. + // For IfOp, the results are simple branches with no loop-carried + // deps, so propagation is safe if we also follow the IfOp results. + if (auto yieldOp = dyn_cast(user)) { + Operation *parent = yieldOp->getParentOp(); + if (isa(parent)) + return false; + if (auto ifOp = dyn_cast(parent)) { + for (Value result : ifOp.getResults()) { + if (isa(result.getType())) + worklist.push_back(result); + } + continue; + } return false; + } // Any other user (dot, reduce, another convert, etc.) blocks // propagation. return false; @@ -2034,6 +2049,7 @@ class TritonGPURemoveLayoutConversionsPass // Collect all ops that need type rewriting (forward from convert users). SmallVector opsToRewrite; + SetVector ifOpsToRewrite; SmallVector worklist = {dst}; DenseSet visited; @@ -2043,8 +2059,20 @@ class TritonGPURemoveLayoutConversionsPass continue; for (OpOperand &use : v.getUses()) { Operation *user = use.getOwner(); - if (isa(user) || isa(user)) + if (isa(user)) + continue; + // For scf.yield under scf.if, follow through to the IfOp results. + // ForOp/WhileOp yields are blocked by canPropagateSrcEncodingThroughUsers. + if (auto yieldOp = dyn_cast(user)) { + if (auto ifOp = dyn_cast(yieldOp->getParentOp())) { + ifOpsToRewrite.insert(ifOp.getOperation()); + for (Value result : ifOp.getResults()) { + if (isa(result.getType())) + worklist.push_back(result); + } + } continue; + } opsToRewrite.push_back(user); for (Value result : user->getResults()) { if (isa(result.getType())) @@ -2116,6 +2144,14 @@ class TritonGPURemoveLayoutConversionsPass } } } + // Rewrite IfOp result types that we propagated through. + for (Operation *op : ifOpsToRewrite) { + for (Value result : op->getResults()) { + if (auto ty = dyn_cast(result.getType())) { + result.setType(ty.cloneWithEncoding(srcEnc)); + } + } + } // Replace all uses of the convert result with the convert source. dst.replaceAllUsesWith(src); diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 85e1326117..b5bdd59038 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -2093,9 +2093,9 @@ def dot(input, other, acc=None, input_precision=None, allow_tf32=None, max_num_i :param input_precision: How to exercise the Tensor Cores for f32 x f32. If the device does not have Tensor Cores or the inputs are not of dtype f32, this option is ignored. For devices that do have tensor cores, the - default precision is tf32. - :type input_precision: string. Available options for nvidia: :code:`"tf32"`, :code:`"tf32x3"`, :code:`"ieee"`. Default: :code:`"tf32"`. Available options for amd: :code:`"ieee"`, (CDNA3 only) :code:`"tf32"`. - :param allow_tf32: *Deprecated.* If true, input_precision is set to "tf32". + default precision is tf32x3. + :type input_precision: string. Available options for nvidia: :code:`"tf32"`, :code:`"tf32x3"`, :code:`"ieee"`. Default: :code:`"tf32x3"`. Available options for amd: :code:`"ieee"`, (CDNA3 only) :code:`"tf32"`. + :param allow_tf32: *Deprecated.* If true, input_precision is set to "tf32x3". Only one of :code:`input_precision` and :code:`allow_tf32` can be specified (i.e. at least one must be :code:`None`). :param attrs: Optional dictionary of string-valued attributes to attach to the dot operation. diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index 9a6310d805..cb1fd5d738 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -1629,8 +1629,8 @@ def dot_precheck( assert input_precision is None or tl._unwrap_if_constexpr(allow_tf32) is None, ( "Only one of input_precision and allow_tf32 can be specified") if input_precision is None: - supports_tf32 = "tf32" in self.builder.options.allowed_dot_input_precisions - input_precision = knobs.language.fp32_default or ("tf32" if + supports_tf32 = "tf32x3" in self.builder.options.allowed_dot_input_precisions + input_precision = knobs.language.fp32_default or ("tf32x3" if (supports_tf32 and (allow_tf32 or allow_tf32 is None)) else "ieee") diff --git a/python/triton/runtime/interpreter.py b/python/triton/runtime/interpreter.py index c5325dfe94..13b3aeb210 100644 --- a/python/triton/runtime/interpreter.py +++ b/python/triton/runtime/interpreter.py @@ -134,7 +134,7 @@ class InterpreterOptions: arch: Optional[str] = None supported_fp8_dtypes: Tuple[str, ...] = ("fp8e5", "fp8e5b16", "fp8e4nv", "fp8e4b8", "fp8e4b15") deprecated_fp8_dot_operand_dtypes: Tuple[str, ...] = () - default_dot_input_precision: str = "tf32" + default_dot_input_precision: str = "tf32x3" allowed_dot_input_precisions: Tuple[str, ...] = ("tf32", "tf32x3", "ieee") max_num_imprecise_acc_default: int = 0 backend_name: str = "interpreter" diff --git a/third_party/nvidia/backend/compiler.py b/third_party/nvidia/backend/compiler.py index cec8046a22..3ce689be65 100644 --- a/third_party/nvidia/backend/compiler.py +++ b/third_party/nvidia/backend/compiler.py @@ -474,7 +474,9 @@ def make_ttgir(mod, metadata, opt, capability): # Budget-aware layout conversion elimination — runs last to ensure # converts whose scratch would exceed SMEM budget are eliminated # after all other passes that may introduce layout conversions. - passes.ttgpuir.add_remove_layout_conversions(pm, smem_budget) + # TODO(njriasan): Re-enable once propagateSrcEncodingAndErase handles + # scf::ForOp/WhileOp loop-carried values correctly. + passes.ttgpuir.add_remove_layout_conversions(pm, 0) pm.run(mod, 'make_ttgir') metadata["tensordesc_meta"] = mod.get_tensordesc_metadata()