From 03abb7663fb1e7148617e8fa70bbfce74610888f Mon Sep 17 00:00:00 2001 From: Peter Bell Date: Wed, 3 May 2023 19:02:10 +0100 Subject: [PATCH] [FRONTEND] Add tl.expand_dims This exposes `semantic.expand_dims` in the public API and builds upon it with support for expanding multiple dimensions at once. e.g. ```python tl.expand_dims(tl.arange(0, N), (0, -1)) # shape = [1, N, 1] ``` Compared to indexing with `None`, this API is useful because the dimensions can be constexpr values rather than hard-coded into the source. --- docs/python-api/triton.language.rst | 1 + python/test/unit/language/test_core.py | 80 ++++++++++++++++++++++++++ python/triton/language/__init__.py | 2 + python/triton/language/core.py | 43 ++++++++++++-- 4 files changed, 122 insertions(+), 4 deletions(-) diff --git a/docs/python-api/triton.language.rst b/docs/python-api/triton.language.rst index 58bce15f3f7b..5013a0242a60 100644 --- a/docs/python-api/triton.language.rst +++ b/docs/python-api/triton.language.rst @@ -34,6 +34,7 @@ Shape Manipulation Ops :nosignatures: broadcast_to + expand_dims reshape ravel diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py index 4ea448e450d8..0fd00dae525c 100644 --- a/python/test/unit/language/test_core.py +++ b/python/test/unit/language/test_core.py @@ -457,6 +457,86 @@ def broadcast_kernel(x_ptr, y_ptr, y_broadcasted_ptr, M: tl.constexpr, N: tl.con assert (y_broadcasted_np == to_numpy(y_broadcasted_tri)).all() +# ---------------- +# test expand_dims +# ---------------- +def test_expand_dims(): + @triton.jit + def expand_dims_kernel(dummy, N: tl.constexpr): + offset1 = tl.arange(0, N) + + t = tl.expand_dims(offset1, 0) + tl.static_assert(t.shape == [1, N]) + + t = tl.expand_dims(offset1, 1) + tl.static_assert(t.shape == [N, 1]) + + t = tl.expand_dims(offset1, -1) + tl.static_assert(t.shape == [N, 1]) + + t = tl.expand_dims(offset1, -2) + tl.static_assert(t.shape == [1, N]) + + t = tl.expand_dims(offset1, (0, -1)) + 
tl.static_assert(t.shape == [1, N, 1]) + + t = tl.expand_dims(offset1, (0, 1, 3)) + tl.static_assert(t.shape == [1, 1, N, 1]) + + t = tl.expand_dims(offset1, (-4, 2, -1)) + tl.static_assert(t.shape == [1, N, 1, 1]) + + t = tl.expand_dims(offset1, (3, 1, 2)) + tl.static_assert(t.shape == [N, 1, 1, 1]) + + N = 32 + dummy_tensor = torch.empty((), device="cuda") + expand_dims_kernel[(1,)](dummy_tensor, N) + + +def test_expand_dims_error_cases(): + @triton.jit + def dim_out_of_range1(dummy, N: tl.constexpr): + offset1 = tl.arange(0, N) + + t = tl.expand_dims(offset1, -2) + t = tl.expand_dims(offset1, -3) + + @triton.jit + def dim_out_of_range2(dummy, N: tl.constexpr): + offset1 = tl.arange(0, N) + + t = tl.expand_dims(offset1, 1) + t = tl.expand_dims(offset1, 2) + + @triton.jit + def duplicate_dim1(dummy, N: tl.constexpr): + offset1 = tl.arange(0, N) + + t = tl.expand_dims(offset1, (0, 0)) + + @triton.jit + def duplicate_dim2(dummy, N: tl.constexpr): + offset1 = tl.arange(0, N) + + t = tl.expand_dims(offset1, (0, -3)) + + N = 32 + dummy_tensor = torch.empty((), device="cuda") + + with pytest.raises(triton.CompilationError, match="invalid axis -3"): + dim_out_of_range1[(1,)](dummy_tensor, N) + + with pytest.raises(triton.CompilationError, match="invalid axis 2"): + dim_out_of_range2[(1,)](dummy_tensor, N) + + with pytest.raises(triton.CompilationError, match=r"duplicate axes, normalized axes = \[0, 0\]"): + duplicate_dim1[(1,)](dummy_tensor, N) + + with pytest.raises(triton.CompilationError, match=r"duplicate axes, normalized axes = \[0, 0\]"): + duplicate_dim2[(1,)](dummy_tensor, N) + + # --------------- # test where # --------------- diff --git a/python/triton/language/__init__.py b/python/triton/language/__init__.py index 92619bf27b7c..7485f374b9e9 100644 --- a/python/triton/language/__init__.py +++ b/python/triton/language/__init__.py @@ -39,6 +39,7 @@ dot, dtype, exp, + expand_dims, full, fdiv, float16, @@ -130,6 +131,7 @@ "dot", "dtype", "exp", + "expand_dims", 
"extra", "fdiv", "float16", diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 3230fa718848..f9e8ad3b91c0 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -3,7 +3,7 @@ from contextlib import contextmanager from enum import Enum from functools import wraps -from typing import Callable, List, TypeVar +from typing import Callable, List, Sequence, TypeVar import triton from . import semantic @@ -883,6 +883,41 @@ def reshape(input, shape, _builder=None): shape = _shape_check_impl(shape) return semantic.reshape(input, shape, _builder) + +def _wrap_axis(axis, ndim): + if not (-ndim <= axis < ndim): + raise ValueError(f"invalid axis {axis}. Expected {-ndim} <= axis < {ndim}") + + return axis if axis >= 0 else axis + ndim + + +@builtin +def expand_dims(input, axis, _builder=None): + """ + Expand the shape of a tensor, by inserting new length-1 dimensions. + + Axis indices are with respect to the resulting tensor, so + ``result.shape[axis]`` will be 1 for each axis. + + :param input: The input tensor. 
+ :type input: tl.tensor + :param axis: The indices to add new axes + :type axis: int | Sequence[int] + + """ + axis = _constexpr_to_value(axis) + axes = list(axis) if isinstance(axis, Sequence) else [axis] + new_ndim = len(input.shape) + len(axes) + axes = [_wrap_axis(_constexpr_to_value(d), new_ndim) for d in axes] + + if len(set(axes)) != len(axes): + raise ValueError(f"expand_dims received duplicate axes, normalized axes = {axes}") + + ret = input + for a in sorted(axes): + ret = semantic.expand_dims(ret, a, _builder) + return ret + # ----------------------- # Linear Algebra # ----------------------- @@ -1281,9 +1316,9 @@ def _argreduce(input, axis, combine_fn, _builder=None, _generator=None): if len(input.shape) > 1: # Broadcast index across the non-reduced axes - expand_dims_index = [constexpr(None)] * len(input.shape) - expand_dims_index[axis] = slice(None) - index = index.__getitem__(expand_dims_index, _builder=_builder) + axes_to_expand = [constexpr(d) for d in range(len(input.shape))] + del axes_to_expand[axis] + index = expand_dims(index, axes_to_expand, _builder=_builder) index = broadcast_to(index, input.shape, _builder=_builder) rvalue, rindices = reduce((input, index), axis, combine_fn,