diff --git a/python/triton/language/core.py b/python/triton/language/core.py index be44b21ff8c2..d192f3ef28d9 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -3250,7 +3250,7 @@ def selu(x, alpha): :return: one tensor or a tuple of tensors, depending on the mapped function. ''' # Build the block for the nested region first to discover the return types - assert pack >= 1 + assert pack >= 1, f"pack must be >= 1, got {pack}" in_scalar_tys = [t.type.scalar for t in args] builder = _semantic.builder block = builder.new_block() @@ -3882,8 +3882,8 @@ def builtin_max(*args, propagate_nan=_NOTHING, _semantic=None): is_constexpr = all(not isinstance(x, base_value) for x in args) if is_constexpr: assert propagate_nan is _NOTHING, "propagate_nan is not supported on builtin max" - assert not any(math.isnan(x) for x in args) - assert not any(is_negative_zero(x) for x in args) + assert not any(math.isnan(x) for x in args), "constexpr max does not support NaN values" + assert not any(is_negative_zero(x) for x in args), "constexpr max does not support negative zero" return constexpr(builtins.max(_unwrap_if_constexpr(args))) if propagate_nan is _NOTHING: @@ -3906,8 +3906,8 @@ def builtin_min(*args, propagate_nan=_NOTHING, _semantic=None): is_constexpr = all(not isinstance(x, base_value) for x in args) if is_constexpr: assert propagate_nan is _NOTHING, "propagate_nan is not supported on builtin min" - assert not any(math.isnan(x) for x in args) - assert not any(is_negative_zero(x) for x in args) + assert not any(math.isnan(x) for x in args), "constexpr min does not support NaN values" + assert not any(is_negative_zero(x) for x in args), "constexpr min does not support negative zero" return constexpr(builtins.min(_unwrap_if_constexpr(args))) if propagate_nan is _NOTHING: diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index d37e6c86c14f..b18693f80893 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -214,8 +214,8 @@ def binary_op_sanitize_overflow_impl(self, lhs: TensorTy, rhs: TensorTy, binary_ return lhs_sca_ty = lhs.type.scalar rhs_sca_ty = rhs.type.scalar - assert lhs_sca_ty == rhs_sca_ty - assert lhs_sca_ty.is_int() + assert lhs_sca_ty == rhs_sca_ty, f"expected matching operand types, got {lhs_sca_ty} and {rhs_sca_ty}" + assert lhs_sca_ty.is_int(), f"expected integer type, got {lhs_sca_ty}" lhs = self.cast(lhs, tl.int64) rhs = self.cast(rhs, tl.int64) ret = binary_op(lhs, rhs, False) @@ -657,7 +657,7 @@ def expand_dims(self, input: TensorTy, axis: int) -> TensorTy: def cat(self, lhs: TensorTy, rhs: TensorTy, can_reorder: bool) -> TensorTy: assert can_reorder, "current implementation of `cat` always may reorder elements" - assert len(lhs.shape) == 1 + assert len(lhs.shape) == 1, f"expected 1D input for cat, got {len(lhs.shape)}D" ret_type = tl.block_type(lhs.type.scalar, [lhs.shape[0] + rhs.shape[0]]) return self.tensor(self.builder.create_cat(lhs.handle, rhs.handle), ret_type) @@ -686,8 +686,9 @@ def join(self, a: TensorTy, b: TensorTy) -> TensorTy: return ret def split(self, a: TensorTy) -> Tuple[TensorTy, TensorTy]: - assert (len(a.shape) > 0) - assert (tl._unwrap_if_constexpr(a.shape[-1]) == 2) + assert (len(a.shape) > 0), "split requires a non-scalar tensor" + assert (tl._unwrap_if_constexpr(a.shape[-1]) == 2), \ + f"expected last dimension to be 2 for split, got {tl._unwrap_if_constexpr(a.shape[-1])}" new_shape = a.shape[:-1] ret_type = tl.block_type(a.type.scalar, new_shape) @@ -754,7 +755,8 @@ def broadcast_impl_value(self, lhs: TensorTy, rhs: TensorTy) -> TensorTy: tl.block_type(rhs_ty.scalar, [1] + rhs_shape.values)) rhs_ty = rhs.type rhs_shape = rhs_ty.get_block_shapes() - assert len(rhs_shape) == len(lhs_shape) + assert len(rhs_shape) == len(lhs_shape), \ + f"expected tensors of equal rank for broadcast, got {len(lhs_shape)} and {len(rhs_shape)}" ret_shape = [] for i, left in enumerate(lhs_shape): @@ -1062,7 +1064,8 @@ def load(self, ptr: TensorTy, mask: Optional[TensorTy], other: Optional[TensorTy def descriptor_load(self, desc: tl.tensor_descriptor_base, offsets, cache_modifier: str, eviction_policy: str) -> TensorTy: - assert isinstance(desc, tl.tensor_descriptor_base) + assert isinstance(desc, tl.tensor_descriptor_base), \ + f"expected a tensor descriptor, got {type(desc).__name__}" ndim = len(desc.block_shape) assert len(offsets) == ndim, f"expected {ndim} offsets, but got {len(offsets)}" @@ -1072,10 +1075,12 @@ def descriptor_load(self, desc: tl.tensor_descriptor_base, offsets, cache_modifi return self.tensor(x, desc.block_type) def validate_store_like(self, desc: tl.tensor_descriptor_base, value: TensorTy, offsets) -> None: - assert isinstance(desc, tl.tensor_descriptor_base) + assert isinstance(desc, tl.tensor_descriptor_base), \ + f"expected a tensor descriptor, got {type(desc).__name__}" ndim = len(desc.block_shape) assert len(offsets) == ndim, f"expected {ndim} offsets, but got {len(offsets)}" - assert value.shape == desc.block_shape + assert value.shape == desc.block_shape, \ + f"expected value shape {desc.block_shape}, got {value.shape}" def descriptor_store(self, desc: tl.tensor_descriptor_base, value: TensorTy, offsets) -> TensorTy: self.validate_store_like(desc, value, offsets) @@ -1136,7 +1141,8 @@ def descriptor_atomic_xor(self, desc: tl.tensor_descriptor_base, value: TensorTy return self.tensor(self.builder.create_descriptor_reduce(kind, desc.handle, value.handle, offsets), tl.void) def descriptor_gather(self, desc, x_offsets, y_offset, cache_modifier: str, eviction_policy: str) -> TensorTy: - assert isinstance(desc, tl.tensor_descriptor_base) + assert isinstance(desc, tl.tensor_descriptor_base), \ + f"expected a tensor descriptor, got {desc.__class__.__name__}" assert cache_modifier == "", "cache modifier is not supported yet" assert eviction_policy == "", "eviction policy is not supported yet" @@ -1162,7 +1168,8 @@ def descriptor_gather(self, desc, x_offsets, y_offset, cache_modifier: str, evic return self.tensor(x, type) def descriptor_scatter(self, desc, value: TensorTy, x_offsets, y_offset) -> TensorTy: - assert isinstance(desc, tl.tensor_descriptor_base) + assert isinstance(desc, tl.tensor_descriptor_base), \ + f"expected a tensor descriptor, got {type(desc).__name__}" # Validate descriptor. assert len(desc.block_shape) == 2, f"descriptor must be 2D, but got {desc.block_shape}" @@ -1421,7 +1428,7 @@ def _str_to_dot_input_precision(self, input_precision): def dot(self, lhs: TensorTy, rhs: TensorTy, acc: TensorTy, input_precision: Optional[str], max_num_imprecise_acc: int, out_dtype: tl.dtype | None) -> TensorTy: - assert lhs.type.is_block() and rhs.type.is_block() + assert lhs.type.is_block() and rhs.type.is_block(), "dot operands must be block tensors (not scalars)" if lhs.dtype.is_fp8() and rhs.dtype.is_fp8(): # All combinations of supported fp8 x fp8 are permitted @@ -1500,7 +1507,11 @@ def dot(self, lhs: TensorTy, rhs: TensorTy, acc: TensorTy, input_precision: Opti acc_handle = self.builder.create_splat(ret_ty.to_ir(self.builder), _0) else: acc_handle = acc.handle - assert acc.type.shape == ret_ty.shape and acc.type.element_ty == out_dtype + assert acc.type.shape == ret_ty.shape, \ + f"expected accumulator shape {ret_ty.shape}, got {acc.type.shape}" + assert acc.type.element_ty == out_dtype, \ + f"expected accumulator dtype {out_dtype}, got {acc.type.element_ty}; " \ + f"pass out_dtype={acc.type.element_ty} to use this accumulator dtype" # max_num_imprecise_acc only applies to fp8 -> fp32 dot on sm_90 if max_num_imprecise_acc is None: @@ -1565,7 +1576,7 @@ def verify_scaled_shape(self, M, N, K, lhs_scale, rhs_scale, scale_factor): def dot_scaled(self, lhs: TensorTy, lhs_scale: TensorTy, lhs_format: str, rhs: TensorTy, rhs_scale: Optional[TensorTy], rhs_format: str, acc: TensorTy | None, fast_math: bool, lhs_k_pack: bool, rhs_k_pack: bool, out_dtype: tl.dtype) -> TensorTy: - assert lhs.type.is_block() and rhs.type.is_block() + assert lhs.type.is_block() and rhs.type.is_block(), "dot_scaled operands must be block tensors (not scalars)" #TODO: validate types. lhs_rank = len(lhs.shape) rhs_rank = len(rhs.shape) @@ -1604,7 +1615,11 @@ def dot_scaled(self, lhs: TensorTy, lhs_scale: TensorTy, lhs_format: str, rhs: T acc_handle = self.builder.create_splat(ret_ty.to_ir(self.builder), _0) else: acc_handle = acc.handle - assert acc.type.shape == ret_ty.shape and acc.type.element_ty == out_dtype + assert acc.type.shape == ret_ty.shape, \ + f"expected accumulator shape {ret_ty.shape}, got {acc.type.shape}" + assert acc.type.element_ty == out_dtype, \ + f"expected accumulator dtype {out_dtype}, got {acc.type.element_ty}; " \ + f"pass out_dtype={acc.type.element_ty} to use this accumulator dtype" rhs_scale_handle = None if rhs_scale_is_none else rhs_scale.handle lhs_scale_handle = None if lhs_scale_is_none else lhs_scale.handle @@ -1842,7 +1857,7 @@ def make_tensor_descriptor(self, base: TensorTy, shape: List[TensorTy], strides: raise ValueError(f"Expected {ndim} strides but got {len(strides)}") if len(block_shape) != ndim: raise ValueError(f"Expected block_shape to have {ndim} dimensions but got {len(strides)}") - assert isinstance(base.dtype, tl.pointer_type) + assert isinstance(base.dtype, tl.pointer_type), f"base must be a pointer type, got {base.dtype}" elem_size = base.dtype.element_ty.primitive_bitwidth // 8 contig_dim_size = tl._unwrap_if_constexpr(block_shape[-1]) if contig_dim_size * elem_size < 16: @@ -1860,7 +1875,7 @@ def make_tensor_descriptor(self, base: TensorTy, shape: List[TensorTy], strides: # Check whether `block_shape` is static block_shape = tl._unwrap_shape(block_shape) - assert isinstance(base.type, tl.pointer_type) + assert isinstance(base.type, tl.pointer_type), f"base must be a pointer type, got {base.type}" type = tl.block_type(base.type.element_ty, block_shape) base_handle = base.handle is_signed_int = base.type.element_ty.is_int_signed()