Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions python/triton/language/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3250,7 +3250,7 @@ def selu(x, alpha):
:return: one tensor or a tuple of tensors, depending on the mapped function.
'''
# Build the block for the nested region first to discover the return types
assert pack >= 1
assert pack >= 1, f"pack must be >= 1, got {pack}"
in_scalar_tys = [t.type.scalar for t in args]
builder = _semantic.builder
block = builder.new_block()
Expand Down Expand Up @@ -3882,8 +3882,8 @@ def builtin_max(*args, propagate_nan=_NOTHING, _semantic=None):
is_constexpr = all(not isinstance(x, base_value) for x in args)
if is_constexpr:
assert propagate_nan is _NOTHING, "propagate_nan is not supported on builtin max"
assert not any(math.isnan(x) for x in args)
assert not any(is_negative_zero(x) for x in args)
assert not any(math.isnan(x) for x in args), "constexpr max does not support NaN values"
assert not any(is_negative_zero(x) for x in args), "constexpr max does not support negative zero"
return constexpr(builtins.max(_unwrap_if_constexpr(args)))

if propagate_nan is _NOTHING:
Expand All @@ -3906,8 +3906,8 @@ def builtin_min(*args, propagate_nan=_NOTHING, _semantic=None):
is_constexpr = all(not isinstance(x, base_value) for x in args)
if is_constexpr:
assert propagate_nan is _NOTHING, "propagate_nan is not supported on builtin min"
assert not any(math.isnan(x) for x in args)
assert not any(is_negative_zero(x) for x in args)
assert not any(math.isnan(x) for x in args), "constexpr min does not support NaN values"
assert not any(is_negative_zero(x) for x in args), "constexpr min does not support negative zero"
return constexpr(builtins.min(_unwrap_if_constexpr(args)))

if propagate_nan is _NOTHING:
Expand Down
49 changes: 32 additions & 17 deletions python/triton/language/semantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,8 +214,8 @@ def binary_op_sanitize_overflow_impl(self, lhs: TensorTy, rhs: TensorTy, binary_
return
lhs_sca_ty = lhs.type.scalar
rhs_sca_ty = rhs.type.scalar
assert lhs_sca_ty == rhs_sca_ty
assert lhs_sca_ty.is_int()
assert lhs_sca_ty == rhs_sca_ty, f"expected matching operand types, got {lhs_sca_ty} and {rhs_sca_ty}"
assert lhs_sca_ty.is_int(), f"expected integer type, got {lhs_sca_ty}"
lhs = self.cast(lhs, tl.int64)
rhs = self.cast(rhs, tl.int64)
ret = binary_op(lhs, rhs, False)
Expand Down Expand Up @@ -657,7 +657,7 @@ def expand_dims(self, input: TensorTy, axis: int) -> TensorTy:

def cat(self, lhs: TensorTy, rhs: TensorTy, can_reorder: bool) -> TensorTy:
    """Concatenate ``lhs`` and ``rhs`` into a single 1D block tensor.

    The result is a block whose scalar type is taken from ``lhs`` and whose
    length is ``lhs.shape[0] + rhs.shape[0]``.

    :param lhs: first operand; must be 1D (checked here).
    :param rhs: second operand; only ``rhs.shape[0]`` is read, its rank is
        not checked in this method.
    :param can_reorder: must be truthy; this implementation rejects
        ``can_reorder=False``.
    :raises AssertionError: if ``can_reorder`` is falsy or ``lhs`` is not 1D.
    """
    assert can_reorder, "current implementation of `cat` always may reorder elements"
    # Single rank check carrying the diagnostic; a preceding bare duplicate
    # would fire first and hide this message.
    assert len(lhs.shape) == 1, f"expected 1D input for cat, got {len(lhs.shape)}D"
    ret_type = tl.block_type(lhs.type.scalar, [lhs.shape[0] + rhs.shape[0]])
    return self.tensor(self.builder.create_cat(lhs.handle, rhs.handle), ret_type)

Expand Down Expand Up @@ -686,8 +686,9 @@ def join(self, a: TensorTy, b: TensorTy) -> TensorTy:
return ret

def split(self, a: TensorTy) -> Tuple[TensorTy, TensorTy]:
assert (len(a.shape) > 0)
assert (tl._unwrap_if_constexpr(a.shape[-1]) == 2)
assert (len(a.shape) > 0), "split requires a non-scalar tensor"
assert (tl._unwrap_if_constexpr(a.shape[-1]) == 2), \
f"expected last dimension to be 2 for split, got {tl._unwrap_if_constexpr(a.shape[-1])}"

new_shape = a.shape[:-1]
ret_type = tl.block_type(a.type.scalar, new_shape)
Expand Down Expand Up @@ -754,7 +755,8 @@ def broadcast_impl_value(self, lhs: TensorTy, rhs: TensorTy) -> TensorTy:
tl.block_type(rhs_ty.scalar, [1] + rhs_shape.values))
rhs_ty = rhs.type
rhs_shape = rhs_ty.get_block_shapes()
assert len(rhs_shape) == len(lhs_shape)
assert len(rhs_shape) == len(lhs_shape), \
f"expected tensors of equal rank for broadcast, got {len(lhs_shape)} and {len(rhs_shape)}"

ret_shape = []
for i, left in enumerate(lhs_shape):
Expand Down Expand Up @@ -1062,7 +1064,8 @@ def load(self, ptr: TensorTy, mask: Optional[TensorTy], other: Optional[TensorTy

def descriptor_load(self, desc: tl.tensor_descriptor_base, offsets, cache_modifier: str,
eviction_policy: str) -> TensorTy:
assert isinstance(desc, tl.tensor_descriptor_base)
assert isinstance(desc, tl.tensor_descriptor_base), \
f"expected a tensor descriptor, got {type(desc).__name__}"
ndim = len(desc.block_shape)
assert len(offsets) == ndim, f"expected {ndim} offsets, but got {len(offsets)}"

Expand All @@ -1072,10 +1075,12 @@ def descriptor_load(self, desc: tl.tensor_descriptor_base, offsets, cache_modifi
return self.tensor(x, desc.block_type)

def validate_store_like(self, desc: tl.tensor_descriptor_base, value: TensorTy, offsets) -> None:
    """Check the arguments of a descriptor store-style operation.

    Verifies, in order, that ``desc`` is a ``tl.tensor_descriptor_base``,
    that one offset is supplied per dimension of ``desc.block_shape``, and
    that ``value.shape`` equals ``desc.block_shape``.

    :param desc: the tensor descriptor being written through.
    :param value: the tensor to store; its shape must match the descriptor's
        block shape.
    :param offsets: sequence of offsets, one per block dimension.
    :raises AssertionError: with a descriptive message on the first failed check.
    """
    # Each condition is asserted exactly once, with its message, so the
    # diagnostic is what the caller sees rather than a bare AssertionError.
    assert isinstance(desc, tl.tensor_descriptor_base), \
        f"expected a tensor descriptor, got {type(desc).__name__}"
    ndim = len(desc.block_shape)
    assert len(offsets) == ndim, f"expected {ndim} offsets, but got {len(offsets)}"
    assert value.shape == desc.block_shape, \
        f"expected value shape {desc.block_shape}, got {value.shape}"

def descriptor_store(self, desc: tl.tensor_descriptor_base, value: TensorTy, offsets) -> TensorTy:
self.validate_store_like(desc, value, offsets)
Expand Down Expand Up @@ -1136,7 +1141,8 @@ def descriptor_atomic_xor(self, desc: tl.tensor_descriptor_base, value: TensorTy
return self.tensor(self.builder.create_descriptor_reduce(kind, desc.handle, value.handle, offsets), tl.void)

def descriptor_gather(self, desc, x_offsets, y_offset, cache_modifier: str, eviction_policy: str) -> TensorTy:
assert isinstance(desc, tl.tensor_descriptor_base)
assert isinstance(desc, tl.tensor_descriptor_base), \
f"expected a tensor descriptor, got {desc.__class__.__name__}"
assert cache_modifier == "", "cache modifier is not supported yet"
assert eviction_policy == "", "eviction policy is not supported yet"

Expand All @@ -1162,7 +1168,8 @@ def descriptor_gather(self, desc, x_offsets, y_offset, cache_modifier: str, evic
return self.tensor(x, type)

def descriptor_scatter(self, desc, value: TensorTy, x_offsets, y_offset) -> TensorTy:
assert isinstance(desc, tl.tensor_descriptor_base)
assert isinstance(desc, tl.tensor_descriptor_base), \
f"expected a tensor descriptor, got {type(desc).__name__}"

# Validate descriptor.
assert len(desc.block_shape) == 2, f"descriptor must be 2D, but got {desc.block_shape}"
Expand Down Expand Up @@ -1421,7 +1428,7 @@ def _str_to_dot_input_precision(self, input_precision):

def dot(self, lhs: TensorTy, rhs: TensorTy, acc: TensorTy, input_precision: Optional[str],
max_num_imprecise_acc: int, out_dtype: tl.dtype | None) -> TensorTy:
assert lhs.type.is_block() and rhs.type.is_block()
assert lhs.type.is_block() and rhs.type.is_block(), "dot operands must be block tensors (not scalars)"

if lhs.dtype.is_fp8() and rhs.dtype.is_fp8():
# All combinations of supported fp8 x fp8 are permitted
Expand Down Expand Up @@ -1500,7 +1507,11 @@ def dot(self, lhs: TensorTy, rhs: TensorTy, acc: TensorTy, input_precision: Opti
acc_handle = self.builder.create_splat(ret_ty.to_ir(self.builder), _0)
else:
acc_handle = acc.handle
assert acc.type.shape == ret_ty.shape and acc.type.element_ty == out_dtype
assert acc.type.shape == ret_ty.shape, \
f"expected accumulator shape {ret_ty.shape}, got {acc.type.shape}"
assert acc.type.element_ty == out_dtype, \
f"expected accumulator dtype {out_dtype}, got {acc.type.element_ty}; " \
f"pass out_dtype={acc.type.element_ty} to use this accumulator dtype"

# max_num_imprecise_acc only applies to fp8 -> fp32 dot on sm_90
if max_num_imprecise_acc is None:
Expand Down Expand Up @@ -1565,7 +1576,7 @@ def verify_scaled_shape(self, M, N, K, lhs_scale, rhs_scale, scale_factor):
def dot_scaled(self, lhs: TensorTy, lhs_scale: TensorTy, lhs_format: str, rhs: TensorTy,
rhs_scale: Optional[TensorTy], rhs_format: str, acc: TensorTy | None, fast_math: bool,
lhs_k_pack: bool, rhs_k_pack: bool, out_dtype: tl.dtype) -> TensorTy:
assert lhs.type.is_block() and rhs.type.is_block()
assert lhs.type.is_block() and rhs.type.is_block(), "dot_scaled operands must be block tensors (not scalars)"
#TODO: validate types.
lhs_rank = len(lhs.shape)
rhs_rank = len(rhs.shape)
Expand Down Expand Up @@ -1604,7 +1615,11 @@ def dot_scaled(self, lhs: TensorTy, lhs_scale: TensorTy, lhs_format: str, rhs: T
acc_handle = self.builder.create_splat(ret_ty.to_ir(self.builder), _0)
else:
acc_handle = acc.handle
assert acc.type.shape == ret_ty.shape and acc.type.element_ty == out_dtype
assert acc.type.shape == ret_ty.shape, \
f"expected accumulator shape {ret_ty.shape}, got {acc.type.shape}"
assert acc.type.element_ty == out_dtype, \
f"expected accumulator dtype {out_dtype}, got {acc.type.element_ty}; " \
f"pass out_dtype={acc.type.element_ty} to use this accumulator dtype"
rhs_scale_handle = None if rhs_scale_is_none else rhs_scale.handle
lhs_scale_handle = None if lhs_scale_is_none else lhs_scale.handle

Expand Down Expand Up @@ -1842,7 +1857,7 @@ def make_tensor_descriptor(self, base: TensorTy, shape: List[TensorTy], strides:
raise ValueError(f"Expected {ndim} strides but got {len(strides)}")
if len(block_shape) != ndim:
raise ValueError(f"Expected block_shape to have {ndim} dimensions but got {len(block_shape)}")
assert isinstance(base.dtype, tl.pointer_type)
assert isinstance(base.dtype, tl.pointer_type), f"base must be a pointer type, got {base.dtype}"
elem_size = base.dtype.element_ty.primitive_bitwidth // 8
contig_dim_size = tl._unwrap_if_constexpr(block_shape[-1])
if contig_dim_size * elem_size < 16:
Expand All @@ -1860,7 +1875,7 @@ def make_tensor_descriptor(self, base: TensorTy, shape: List[TensorTy], strides:
# Check whether `block_shape` is static
block_shape = tl._unwrap_shape(block_shape)

assert isinstance(base.type, tl.pointer_type)
assert isinstance(base.type, tl.pointer_type), f"base must be a pointer type, got {base.type}"
type = tl.block_type(base.type.element_ty, block_shape)
base_handle = base.handle
is_signed_int = base.type.element_ty.is_int_signed()
Expand Down
Loading