Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions python/triton/language/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2025,17 +2025,17 @@ def dot(input, other, acc=None, input_precision=None, allow_tf32=None, max_num_i
:param input_precision: How to exercise the Tensor Cores for f32 x f32. If
the device does not have Tensor Cores or the inputs are not of dtype f32,
this option is ignored. For devices that do have tensor cores, the
default precision is tf32.
:type input_precision: string. Available options for nvidia: :code:`"tf32"`, :code:`"tf32x3"`, :code:`"ieee"`. Default: :code:`"tf32"`. Available options for amd: :code:`"ieee"`, (CDNA3 only) :code:`"tf32"`.
:param allow_tf32: *Deprecated.* If true, input_precision is set to "tf32".
default precision is tf32x3.
:type input_precision: string. Available options for nvidia: :code:`"tf32"`, :code:`"tf32x3"`, :code:`"ieee"`. Default: :code:`"tf32x3"`. Available options for amd: :code:`"ieee"`, (CDNA3 only) :code:`"tf32"`.
:param allow_tf32: *Deprecated.* If true, input_precision is set to "tf32x3".
Only one of :code:`input_precision` and :code:`allow_tf32` can be
specified (i.e. at least one must be :code:`None`).
"""
assert input_precision is None or allow_tf32 is None, "Only one of input_precision and allow_tf32 can be specified"
if input_precision is None:
supports_tf32 = "tf32" in _semantic.builder.options.allowed_dot_input_precisions
input_precision = knobs.language.fp32_default or ("tf32" if (supports_tf32 and
(allow_tf32 or allow_tf32 is None)) else "ieee")
supports_tf32 = "tf32x3" in _semantic.builder.options.allowed_dot_input_precisions
input_precision = knobs.language.fp32_default or ("tf32x3" if (supports_tf32 and
(allow_tf32 or allow_tf32 is None)) else "ieee")

input_precision = _unwrap_if_constexpr(input_precision)
out_dtype = _unwrap_if_constexpr(out_dtype)
Expand Down
2 changes: 1 addition & 1 deletion python/triton/runtime/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ class InterpreterOptions:
arch: Optional[str] = None
supported_fp8_dtypes: Tuple[str, ...] = ("fp8e5", "fp8e5b16", "fp8e4nv", "fp8e4b8", "fp8e4b15")
deprecated_fp8_dot_operand_dtypes: Tuple[str, ...] = ()
default_dot_input_precision: str = "tf32"
default_dot_input_precision: str = "tf32x3"
allowed_dot_input_precisions: Tuple[str, ...] = ("tf32", "tf32x3", "ieee")
max_num_imprecise_acc_default: int = 0
backend_name: str = "interpreter"
Expand Down
Loading