diff --git a/python/test/unit/language/test_frontend.py b/python/test/unit/language/test_frontend.py
index f7bd51bd591c..253836c3faae 100644
--- a/python/test/unit/language/test_frontend.py
+++ b/python/test/unit/language/test_frontend.py
@@ -1072,3 +1072,17 @@ def test_aggregate_replace_ir():
     # Original aggregate still references original tensor.
     # CHECK: call @{{.*}}anchor{{.*}}([[A]])
     anchor(state.vals)
+
+
+def test_dot_fp16_accumulator():
+    # Parse a kernel that passes an fp16 accumulator to tl.dot without an
+    # explicit out_dtype; the result dtype is then taken from the accumulator.
+
+    @triton.jit
+    def fp16_acc_kernel():
+        c = tl.zeros([16, 16], dtype=tl.float16)
+        a = tl.full([16, 16], 1, dtype=tl.float16)
+        b = tl.full([16, 16], 1, dtype=tl.float16)
+        tl.dot(a, b, c)
+
+    run_parser(fp16_acc_kernel)
diff --git a/python/triton/language/core.py b/python/triton/language/core.py
index bdafc0788294..be44b21ff8c2 100644
--- a/python/triton/language/core.py
+++ b/python/triton/language/core.py
@@ -2350,7 +2350,7 @@ def cast(input, dtype: dtype, fp_downcast_rounding: Optional[str] = None, bitcas
 
 
 @builtin
-def dot(input, other, acc=None, input_precision=None, allow_tf32=None, max_num_imprecise_acc=None, out_dtype=float32,
+def dot(input, other, acc=None, input_precision=None, allow_tf32=None, max_num_imprecise_acc=None, out_dtype=None,
         _semantic=None):
     """
     Returns the matrix product of two blocks.
diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py
index ebc26b843943..d37e6c86c14f 100644
--- a/python/triton/language/semantic.py
+++ b/python/triton/language/semantic.py
@@ -1420,7 +1420,7 @@ def _str_to_dot_input_precision(self, input_precision):
         return getattr(ir.INPUT_PRECISION, input_precision)
 
-    def dot(self, lhs: TensorTy, rhs: TensorTy, acc: TensorTy, input_precision: Optional[str],
-            max_num_imprecise_acc: int, out_dtype: tl.dtype) -> TensorTy:
+    def dot(self, lhs: TensorTy, rhs: TensorTy, acc: Optional[TensorTy], input_precision: Optional[str],
+            max_num_imprecise_acc: int, out_dtype: Optional[tl.dtype]) -> TensorTy:
         assert lhs.type.is_block() and rhs.type.is_block()
 
         if lhs.dtype.is_fp8() and rhs.dtype.is_fp8():
@@ -1457,6 +1457,11 @@ def dot(self, lhs: TensorTy, rhs: TensorTy, acc: TensorTy, input_precision: Opti
         if input_precision is None:
             input_precision = self.builder.options.default_dot_input_precision
 
+        if out_dtype is None:
+            # No result type was requested: follow the accumulator's element type
+            # when an accumulator is given, otherwise default to float32.
+            out_dtype = tl.float32 if acc is None else acc.type.element_ty
+
         input_precision = self._str_to_dot_input_precision(input_precision)
 
         lhs_rank = len(lhs.shape)