Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions python/test/unit/language/test_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -1072,3 +1072,15 @@ def test_aggregate_replace_ir():
# Original aggregate still references original tensor.
# CHECK: call @{{.*}}anchor{{.*}}([[A]])
anchor(state.vals)


def test_dot_fp16_accumulator():

@triton.jit
def fp16_acc_kernel():
c = tl.zeros([16, 16], dtype=tl.float16)
a = tl.full([16, 16], 1, dtype=tl.float16)
b = tl.full([16, 16], 1, dtype=tl.float16)
tl.dot(a, b, c)

run_parser(fp16_acc_kernel)
2 changes: 1 addition & 1 deletion python/triton/language/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2350,7 +2350,7 @@ def cast(input, dtype: dtype, fp_downcast_rounding: Optional[str] = None, bitcas


@builtin
def dot(input, other, acc=None, input_precision=None, allow_tf32=None, max_num_imprecise_acc=None, out_dtype=float32,
def dot(input, other, acc=None, input_precision=None, allow_tf32=None, max_num_imprecise_acc=None, out_dtype=None,
_semantic=None):
"""
Returns the matrix product of two blocks.
Expand Down
5 changes: 4 additions & 1 deletion python/triton/language/semantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1420,7 +1420,7 @@ def _str_to_dot_input_precision(self, input_precision):
return getattr(ir.INPUT_PRECISION, input_precision)

def dot(self, lhs: TensorTy, rhs: TensorTy, acc: TensorTy, input_precision: Optional[str],
max_num_imprecise_acc: int, out_dtype: tl.dtype) -> TensorTy:
max_num_imprecise_acc: int, out_dtype: tl.dtype | None) -> TensorTy:
assert lhs.type.is_block() and rhs.type.is_block()

if lhs.dtype.is_fp8() and rhs.dtype.is_fp8():
Expand Down Expand Up @@ -1457,6 +1457,9 @@ def dot(self, lhs: TensorTy, rhs: TensorTy, acc: TensorTy, input_precision: Opti
if input_precision is None:
input_precision = self.builder.options.default_dot_input_precision

if out_dtype is None:
out_dtype = tl.float32 if acc is None else acc.type.element_ty

input_precision = self._str_to_dot_input_precision(input_precision)

lhs_rank = len(lhs.shape)
Expand Down
Loading