Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 59 additions & 32 deletions python/test/unit/language/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -568,7 +568,7 @@ def test_compare_op(dtype_x, dtype_y, op, mode_x, mode_y, num_ctas, device):
# test broadcast
# ---------------
@pytest.mark.parametrize("dtype", dtypes_with_bfloat16)
def test_broadcast(dtype):
def test_broadcast(dtype, device):
@triton.jit
def broadcast_kernel(x_ptr, y_ptr, y_broadcasted_ptr, M: tl.constexpr, N: tl.constexpr):
offset1 = tl.arange(0, M)
Expand All @@ -585,41 +585,42 @@ def broadcast_kernel(x_ptr, y_ptr, y_broadcasted_ptr, M: tl.constexpr, N: tl.con
y = numpy_random(N, dtype_str=dtype, rs=rs)
_, y_broadcasted_np = np.broadcast_arrays(x, y)

x_tri = to_triton(x, device='cuda', dst_type=dtype)
y_tri = to_triton(y, device='cuda', dst_type=dtype)
y_broadcasted_tri = to_triton(np.empty((M, N), dtype=y_broadcasted_np.dtype), device='cuda', dst_type=dtype)
x_tri = to_triton(x, device=device, dst_type=dtype)
y_tri = to_triton(y, device=device, dst_type=dtype)
y_broadcasted_tri = to_triton(np.empty((M, N), dtype=y_broadcasted_np.dtype), device=device, dst_type=dtype)

broadcast_kernel[(1,)](x_tri, y_tri, y_broadcasted_tri, M=M, N=N)
assert (y_broadcasted_np == to_numpy(y_broadcasted_tri)).all()

# ----------
# test slice
# ----------


def test_slice(device):

# ---------------
# test broadcast
# ---------------
@pytest.mark.parametrize("dtype", dtypes_with_bfloat16)
def test_broadcast(dtype, device):
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unrelated change, but this test overwrites the test above and looks like it might have been a bad merge conflict resolution. I've removed the shadowed test that is identical other than hard coding device="cuda".

@triton.jit
def broadcast_kernel(x_ptr, y_ptr, y_broadcasted_ptr, M: tl.constexpr, N: tl.constexpr):
offset1 = tl.arange(0, M)
offset2 = tl.arange(0, N)
x = tl.load(x_ptr + N * offset1[:, None] + offset2[None, :])
y = tl.load(y_ptr + offset2)
_, y_broadcasted = tl.broadcast(x, y)
tl.store(y_broadcasted_ptr + N * offset1[:, None] + offset2[None, :], y_broadcasted)
def slice_kernel(XBLOCK: tl.constexpr):
data = tl.arange(0, XBLOCK)
tl.static_assert(data.shape == [XBLOCK])

M = 32
N = 64
rs = RandomState(17)
x = numpy_random((M, N), dtype_str=dtype, rs=rs)
y = numpy_random(N, dtype_str=dtype, rs=rs)
_, y_broadcasted_np = np.broadcast_arrays(x, y)
t = data[None, :]
tl.static_assert(t.shape == [1, XBLOCK])

x_tri = to_triton(x, device=device, dst_type=dtype)
y_tri = to_triton(y, device=device, dst_type=dtype)
y_broadcasted_tri = to_triton(np.empty((M, N), dtype=y_broadcasted_np.dtype), device=device, dst_type=dtype)
t = data[None, :, None]
tl.static_assert(t.shape == [1, XBLOCK, 1])

scalar = tl.full([], 1, tl.int32)
tl.static_assert(scalar.shape == [])

t = scalar[None]
tl.static_assert(t.shape == [1])

t = scalar[None, None]
tl.static_assert(t.shape == [1, 1])

slice_kernel[(1,)](XBLOCK=32)

broadcast_kernel[(1,)](x_tri, y_tri, y_broadcasted_tri, M=M, N=N)
assert (y_broadcasted_np == to_numpy(y_broadcasted_tri)).all()

# ------------------
# test invalid slice
Expand Down Expand Up @@ -669,6 +670,14 @@ def expand_dims_kernel(dummy, N: tl.constexpr):
t = tl.expand_dims(offset1, (3, 1, 2))
tl.static_assert(t.shape == [N, 1, 1, 1])

scalar = tl.sum(offset1)
tl.static_assert(scalar.shape == [])
t = tl.expand_dims(scalar, 0)
tl.static_assert(t.shape == [1])

t = tl.expand_dims(scalar, -1)
tl.static_assert(t.shape == [1])

N = 32
dummy_tensor = torch.empty((), device=device)
expand_dims_kernel[(1,)](dummy_tensor, N)
Expand All @@ -689,6 +698,13 @@ def dim_out_of_range2(dummy, N: tl.constexpr):
t = tl.expand_dims(offset1, 1)
t = tl.expand_dims(offset1, 2)

@triton.jit
def dim_out_of_range3(dummy, N: tl.constexpr):
offset1 = tl.arange(0, 1)
scalar = tl.sum(offset1)

t = tl.expand_dims(scalar, 1)

@triton.jit
def duplicate_dim1(dummy, N: tl.constexpr):
offset1 = tl.arange(0, N)
Expand All @@ -710,6 +726,9 @@ def duplicate_dim2(dummy, N: tl.constexpr):
with pytest.raises(triton.CompilationError, match="invalid axis 2"):
dim_out_of_range2[(1,)](dummy_tensor, N)

with pytest.raises(triton.CompilationError, match="invalid axis 1"):
dim_out_of_range3[(1,)](dummy_tensor, N)

with pytest.raises(triton.CompilationError, match=r"duplicate axes, normalized axes = \[0, 0\]"):
duplicate_dim1[(1,)](dummy_tensor, N)

Expand Down Expand Up @@ -2467,7 +2486,8 @@ def kernel(Z, X, Y,


@pytest.mark.parametrize("dtype_str", int_dtypes + uint_dtypes + float_dtypes + ['bfloat16'])
def test_full(dtype_str, device):
@pytest.mark.parametrize("shape", [(), (1,), (128,)])
def test_full(dtype_str, shape, device):
if dtype_str in uint_dtypes and not hasattr(torch, dtype_str):
# PyTorch only has unsigned 8, but not 16, 32, or 64
dtype = getattr(torch, dtype_str[1:]) # uintx -> intx
Expand All @@ -2478,21 +2498,28 @@ def test_full(dtype_str, device):
@triton.jit
def kernel_static(out):
a = GENERATE_TEST_HERE
tl.static_assert(a.shape == SHAPE)
out_ptr = out + tl.arange(0, 128)[:]
tl.store(out_ptr, a)

@triton.jit
def kernel_dynamic(out, val, dtype: tl.constexpr):
a = tl.full((128,), val, dtype)
a = tl.full(SHAPE, val, dtype)
tl.static_assert(a.shape == SHAPE)
out_ptr = out + tl.arange(0, 128)[:]
tl.store(out_ptr, a)

kernel_static_patched = patch_kernel(kernel_static, {'GENERATE_TEST_HERE': f"tl.full((128,), 2, tl.{dtype_str})"})
kernel_static_patched = patch_kernel(kernel_static, {
'GENERATE_TEST_HERE': f"tl.full({shape}, 2, tl.{dtype_str})",
'SHAPE': str(list(shape)),
})
out_static = torch.zeros((128), dtype=dtype, device=device)
kernel_static_patched[(1,)](out_static)
out_dynamic = torch.zeros((128), dtype=dtype, device=device)
kernel_dynamic[(1,)](out_dynamic, 2, getattr(triton.language, dtype_str))
assert torch.all(out_static == 2)

kernel_dynamic_patched = patch_kernel(kernel_dynamic, {'SHAPE': str(list(shape))})
out_dynamic = torch.zeros((128), dtype=dtype, device=device)
kernel_dynamic_patched[(1,)](out_dynamic, 2, getattr(triton.language, dtype_str))
assert torch.all(out_dynamic == 2)


Expand Down
6 changes: 2 additions & 4 deletions python/triton/language/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -531,9 +531,7 @@ def __init__(self, handle, type: dtype):
# IR handle
self.handle = handle
# Block shape
self.shape = (1, )
if type.is_block():
self.shape = type.shape
self.shape = type.shape if type.is_block() else ()
self.numel = 1
for s in self.shape:
self.numel *= s
Expand Down Expand Up @@ -743,7 +741,7 @@ def __not__(self, _builder=None):

@builtin
def __getitem__(self, slices, _builder=None):
if isinstance(slices, slice):
if isinstance(slices, (slice, constexpr)):
slices = [slices]
ret = self
for dim, sl in enumerate(slices):
Expand Down
26 changes: 18 additions & 8 deletions python/triton/language/semantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,25 +501,31 @@ def full(shape: List[int], value, dtype: tl.dtype, builder: ir.builder) -> tl.te
if isinstance(value, tl.tensor):
assert value.numel.value == 1, "only accepts size-1 tensor"
value = cast(value, dtype, builder)
ret_ty = tl.block_type(value.dtype, shape)
return tl.tensor(builder.create_splat(value.handle, shape), ret_ty)
else:
# scalar
if dtype is None:
raise ValueError("dtype must be specified when value is not a tensor")
if value == 0:
value = builder.get_null_value(dtype.to_ir(builder))
else:
get_value_fn = getattr(builder, f"get_{dtype.name}")
value = get_value_fn(value)
if dtype is None:
raise ValueError("dtype must be specified when value is not a tensor")
ret_ty = tl.block_type(dtype, shape)
return tl.tensor(builder.create_splat(value, shape), ret_ty)
value = tl.tensor(value, dtype)

return splat(value, shape, builder)


# ===----------------------------------------------------------------------===//
# Shape Manipulation
# ===----------------------------------------------------------------------===//

def splat(value: tl.tensor, shape: List[int], builder: ir.builder) -> tl.tensor:
assert not value.type.is_block(), "Cannot splat a block tensor"
if len(shape) == 0:
return value
ret_ty = tl.block_type(value.dtype, shape)
return tl.tensor(builder.create_splat(value.handle, shape), ret_ty)


def view(input: tl.tensor,
dst_shape: List[int],
Expand All @@ -544,8 +550,12 @@ def reshape(input: tl.tensor,


def expand_dims(input: tl.tensor, axis: int, builder: ir.builder) -> tl.tensor:
dst_shape = list(input.type.shape)
dst_shape = [tl._constexpr_to_value(x) for x in input.shape]
dst_shape.insert(axis, 1)

if not input.type.is_block():
return splat(input, shape=dst_shape, builder=builder)

ret_ty = tl.block_type(input.type.scalar, dst_shape)
return tl.tensor(builder.create_expand_dims(input.handle, axis), ret_ty)

Expand Down Expand Up @@ -1504,7 +1514,7 @@ def abs(x: tl.tensor, builder: ir.builder) -> tl.tensor:


def multiple_of(x: tl.tensor, values: List[int]) -> tl.tensor:
if len(x.shape) != len(values):
if max(1, len(x.shape)) != len(values):
raise ValueError("Shape of input to multiple_of does not match the length of values")
x.handle.set_attr("tt.divisibility", ir.make_attr(values, x.handle.get_context()))
return x
Expand Down