From e6195a80375a0c24665c1179c95e41a27debd23f Mon Sep 17 00:00:00 2001 From: mrTsjolder Date: Fri, 22 May 2026 21:06:23 +0200 Subject: [PATCH 1/2] use acc dtype if out_dtype is not specified instead of float32 --- python/triton/language/core.py | 2 +- python/triton/language/semantic.py | 5 ++++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 38793213fae9..9963c4529c69 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -2350,7 +2350,7 @@ def cast(input, dtype: dtype, fp_downcast_rounding: Optional[str] = None, bitcas @builtin -def dot(input, other, acc=None, input_precision=None, allow_tf32=None, max_num_imprecise_acc=None, out_dtype=float32, +def dot(input, other, acc=None, input_precision=None, allow_tf32=None, max_num_imprecise_acc=None, out_dtype=None, _semantic=None): """ Returns the matrix product of two blocks. diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index ebc26b843943..d37e6c86c14f 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -1420,7 +1420,7 @@ def _str_to_dot_input_precision(self, input_precision): return getattr(ir.INPUT_PRECISION, input_precision) def dot(self, lhs: TensorTy, rhs: TensorTy, acc: TensorTy, input_precision: Optional[str], - max_num_imprecise_acc: int, out_dtype: tl.dtype) -> TensorTy: + max_num_imprecise_acc: int, out_dtype: tl.dtype | None) -> TensorTy: assert lhs.type.is_block() and rhs.type.is_block() if lhs.dtype.is_fp8() and rhs.dtype.is_fp8(): @@ -1457,6 +1457,9 @@ def dot(self, lhs: TensorTy, rhs: TensorTy, acc: TensorTy, input_precision: Opti if input_precision is None: input_precision = self.builder.options.default_dot_input_precision + if out_dtype is None: + out_dtype = tl.float32 if acc is None else acc.type.element_ty + input_precision = self._str_to_dot_input_precision(input_precision) lhs_rank = len(lhs.shape) From 865db444fd1715bc33b3c64af395aff7899a36e6 Mon Sep 17 00:00:00 2001 From: mrTsjolder Date: Fri, 22 May 2026 21:07:13 +0200 Subject: [PATCH 2/2] add test for tl.dot with fp16 accumulator --- python/test/unit/language/test_frontend.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/python/test/unit/language/test_frontend.py b/python/test/unit/language/test_frontend.py index f7bd51bd591c..253836c3faae 100644 --- a/python/test/unit/language/test_frontend.py +++ b/python/test/unit/language/test_frontend.py @@ -1072,3 +1072,15 @@ def test_aggregate_replace_ir(): # Original aggregate still references original tensor. # CHECK: call @{{.*}}anchor{{.*}}([[A]]) anchor(state.vals) + + +def test_dot_fp16_accumulator(): + + @triton.jit + def fp16_acc_kernel(): + c = tl.zeros([16, 16], dtype=tl.float16) + a = tl.full([16, 16], 1, dtype=tl.float16) + b = tl.full([16, 16], 1, dtype=tl.float16) + tl.dot(a, b, c) + + run_parser(fp16_acc_kernel)