1 change: 1 addition & 0 deletions docs/python-api/triton.language.rst
@@ -34,6 +34,7 @@ Shape Manipulation Ops
:nosignatures:

broadcast_to
expand_dims
reshape
ravel

80 changes: 80 additions & 0 deletions python/test/unit/language/test_core.py
@@ -457,6 +457,86 @@ def broadcast_kernel(x_ptr, y_ptr, y_broadcasted_ptr, M: tl.constexpr, N: tl.constexpr):
assert (y_broadcasted_np == to_numpy(y_broadcasted_tri)).all()


# ----------------
# test expand_dims
# ----------------
def test_expand_dims():
@triton.jit
def expand_dims_kernel(dummy, N: tl.constexpr):
offset1 = tl.arange(0, N)

t = tl.expand_dims(offset1, 0)
tl.static_assert(t.shape == [1, N])
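# equivalent to the existing indexing idiom offset1[None, :]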

t = tl.expand_dims(offset1, 1)
tl.static_assert(t.shape == [N, 1])

t = tl.expand_dims(offset1, -1)
tl.static_assert(t.shape == [N, 1])

t = tl.expand_dims(offset1, -2)
tl.static_assert(t.shape == [1, N])

t = tl.expand_dims(offset1, (0, -1))
tl.static_assert(t.shape == [1, N, 1])

t = tl.expand_dims(offset1, (0, 1, 3))
tl.static_assert(t.shape == [1, 1, N, 1])

t = tl.expand_dims(offset1, (-4, 2, -1))
tl.static_assert(t.shape == [1, N, 1, 1])

t = tl.expand_dims(offset1, (3, 1, 2))
tl.static_assert(t.shape == [N, 1, 1, 1])

N = 32
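# `dummy` is never read inside the kernel; it only gives the launch a device argument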
dummy_tensor = torch.empty((), device="cuda")
expand_dims_kernel[(1,)](dummy_tensor, N)


def test_expand_dims_error_cases():
@triton.jit
def dim_out_of_range1(dummy, N: tl.constexpr):
offset1 = tl.arange(0, N)

t = tl.expand_dims(offset1, -2)
t = tl.expand_dims(offset1, -3)

@triton.jit
def dim_out_of_range2(dummy, N: tl.constexpr):
offset1 = tl.arange(0, N)

t = tl.expand_dims(offset1, 1)
t = tl.expand_dims(offset1, 2)

@triton.jit
def duplicate_dim1(dummy, N: tl.constexpr):
offset1 = tl.arange(0, N)

t = tl.expand_dims(offset1, (0, 0))

@triton.jit
def duplicate_dim2(dummy, N: tl.constexpr):
offset1 = tl.arange(0, N)

t = tl.expand_dims(offset1, (0, -3))
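# a 1-D input plus two new axes yields a 3-D result, so -3
# normalizes to 0 and collides with the explicit axis 0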

N = 32
dummy_tensor = torch.empty((), device="cuda")

with pytest.raises(triton.CompilationError, match="invalid axis -3"):
dim_out_of_range1[(1,)](dummy_tensor, N)

with pytest.raises(triton.CompilationError, match="invalid axis 2"):
dim_out_of_range2[(1,)](dummy_tensor, N)

with pytest.raises(triton.CompilationError, match=r"duplicate axes, normalized axes = \[0, 0\]"):
duplicate_dim1[(1,)](dummy_tensor, N)

with pytest.raises(triton.CompilationError, match=r"duplicate axes, normalized axes = \[0, 0\]"):
duplicate_dim2[(1,)](dummy_tensor, N)


# ---------------
# test where
# ---------------
2 changes: 2 additions & 0 deletions python/triton/language/__init__.py
@@ -39,6 +39,7 @@
dot,
dtype,
exp,
expand_dims,
full,
fdiv,
float16,
@@ -130,6 +131,7 @@
"dot",
"dtype",
"exp",
"expand_dims",
"extra",
"fdiv",
"float16",
43 changes: 39 additions & 4 deletions python/triton/language/core.py
@@ -3,7 +3,7 @@
from contextlib import contextmanager
from enum import Enum
from functools import wraps
from typing import Callable, List, TypeVar
from typing import Callable, List, Sequence, TypeVar

import triton
from . import semantic
@@ -883,6 +883,41 @@ def reshape(input, shape, _builder=None):
shape = _shape_check_impl(shape)
return semantic.reshape(input, shape, _builder)


def _wrap_axis(axis, ndim):
if not (-ndim <= axis < ndim):
raise ValueError(f"invalid axis {axis}. Expected {-ndim} <= axis < {ndim}")

return axis if axis >= 0 else axis + ndim
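# e.g. _wrap_axis(-1, 4) == 3, while _wrap_axis(4, 4) raises ValueError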


@builtin
def expand_dims(input, axis, _builder=None):
"""
Expand the shape of a tensor, by inserting new length-1 dimensions.

Axis indices are with respect to the resulting tensor, so
``result.shape[axis]`` will be 1 for each axis.

:param input: The input tensor.
:type input: tl.tensor
:param axis: The indices to add new axes
:type axis: int | Sequence[int]

"""
axis = _constexpr_to_value(axis)
axes = list(axis) if isinstance(axis, Sequence) else [axis]
new_ndim = len(input.shape) + len(axes)
axes = [_wrap_axis(_constexpr_to_value(d), new_ndim) for d in axes]

if len(set(axes)) != len(axes):
raise ValueError(f"expand_dims recieved duplicate axes, normalized axes = {axes}")

ret = input
for a in sorted(axes):
ret = semantic.expand_dims(ret, a, _builder)
return ret
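# A minimal usage sketch (hypothetical kernel, for illustration only;
# it mirrors the unit tests above):
#
#   @triton.jit
#   def kernel(N: tl.constexpr):
#       offs = tl.arange(0, N)                # shape [N]
#       row = tl.expand_dims(offs, 0)         # shape [1, N]
#       col = tl.expand_dims(offs, -1)        # shape [N, 1]
#       both = tl.expand_dims(offs, (0, -1))  # shape [1, N, 1]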

# -----------------------
# Linear Algebra
# -----------------------
@@ -1281,9 +1316,9 @@ def _argreduce(input, axis, combine_fn, _builder=None, _generator=None):

if len(input.shape) > 1:
# Broadcast index across the non-reduced axes
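# e.g. input.shape == [M, K, N] and axis == 1: index starts with shape [K];
# expanding at axes [0, 2] gives [1, K, 1], which broadcasts to [M, K, N]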
expand_dims_index = [constexpr(None)] * len(input.shape)
expand_dims_index[axis] = slice(None)
index = index.__getitem__(expand_dims_index, _builder=_builder)
axes_to_expand = [constexpr(d) for d in range(len(input.shape))]
del axes_to_expand[axis]
index = expand_dims(index, axes_to_expand, _builder=_builder)
index = broadcast_to(index, input.shape, _builder=_builder)

rvalue, rindices = reduce((input, index), axis, combine_fn,