diff --git a/basalt/autograd/ops/mlops.mojo b/basalt/autograd/ops/mlops.mojo
index 0869919..29f5e39 100644
--- a/basalt/autograd/ops/mlops.mojo
+++ b/basalt/autograd/ops/mlops.mojo
@@ -4,6 +4,7 @@
 from math.limit import min_finite, max_finite
 
 from basalt import Tensor, TensorShape
 from basalt.utils.tensorutils import elwise_transform
+from basalt.utils.itertools import product
 from basalt.autograd.attributes import Attribute, AttributeVector
 
@@ -491,4 +492,137 @@ struct SLICE:
         Self.slice_kernel[ug_shape, t1_shape, steps, starts, ends, True](res_grad, ug)
 
-        return res_grad ^
\ No newline at end of file
+        return res_grad ^
+
+
+struct INDEX:
+    @staticmethod
+    fn adjust_boundary(slice: Int, dim_size: Int) -> Int:
+        # Adjust negative indices & ensure they are within bounds.
+        var s = slice if slice >= 0 else dim_size + slice
+        return max(min(s, dim_size), 0)
+
+    @staticmethod
+    fn to_indices(shape: TensorShape, attrs: AttributeVector) -> List[List[Int]]:
+        var SLICE_LITERALS = List[StringLiteral]("dim_0s", "dim_1s", "dim_2s", "dim_3s", "dim_4s", "dim_5s", "dim_6s", "dim_7s")
+        var INDEX_LITERALS = List[StringLiteral]("dim_0i", "dim_1i", "dim_2i", "dim_3i", "dim_4i", "dim_5i", "dim_6i", "dim_7i")
+
+        var rank = shape.rank()
+        var indices = List[List[Int]]()
+        indices.reserve(rank)
+
+        for dim in range(rank):
+            var temp = List[Int]()
+
+            # Option 1: Slice
+            if attrs[SLICE_LITERALS[dim]]:
+                var slice = attrs[SLICE_LITERALS[dim]].value().to_shape()
+                var step = slice[2] if slice.rank() == 3 else 1
+                for i in range(
+                    start=Self.adjust_boundary(slice[0], shape[dim]),
+                    end=Self.adjust_boundary(slice[1], shape[dim]),
+                    step=step
+                ):
+                    temp.append(i)
+
+            # Option 2: Explicit indices
+            elif attrs[INDEX_LITERALS[dim]]:
+                var dim_indices = attrs[INDEX_LITERALS[dim]].value().to_shape()
+                for i in range(dim_indices.rank()):
+                    temp.append(dim_indices[i])
+
+            # Option 3: All indices along this dimension
+            else:
+                for i in range(shape[dim]):
+                    temp.append(i)
+
+            indices.append(temp)
+
+        return indices ^
+
+    @staticmethod
+    fn result_shape(shape: TensorShape, attrs: AttributeVector) -> TensorShape:
+        var indices = Self.to_indices(shape, attrs)
+        var rank = shape.rank()
+        var new_shape = List[Int]()
+        new_shape.reserve(rank)
+        for i in range(rank):
+            new_shape.append(len(indices[i]))
+        return TensorShape(new_shape)
+
+    @staticmethod
+    fn map_indices[
+        nelts: Int,
+        strides: TensorShape,
+        indices: List[List[Int]],
+    ](idx: Int) -> SIMD[DType.int64, nelts]:
+        alias indices_product = product(indices)
+
+        var temp = SIMD[DType.int64, nelts]()
+        for i in range(idx, idx + nelts):
+            var comb = indices_product[i]
+            var flat_index = 0
+
+            for dim in range(len(comb)):
+                flat_index += comb[dim] * strides[dim]
+
+            temp[i - idx] = flat_index
+
+        return temp
+
+    @staticmethod
+    fn forward[
+        t1_shape: TensorShape,
+        attributes: AttributeVector,
+    ](inout res: Tensor[dtype], t1: Tensor[dtype]):
+        alias indices = Self.to_indices(t1_shape, attributes)
+        alias strides = t1_shape.strides()
+        alias total_length = len(product(indices))
+
+        @parameter
+        fn vec_index[nelts: Int](i: Int):
+            res.store[nelts](
+                i, t1.data().gather(Self.map_indices[nelts, strides, indices](i))
+            )
+
+        vectorize[vec_index, nelts](total_length)
+
+    @staticmethod
+    fn backward[
+        ug_shape: TensorShape,
+        t1_shape: TensorShape,
+        attributes: AttributeVector = AttributeVector(),
+    ](ug: Tensor[dtype], t1: Tensor[dtype]) -> Tensor[dtype]:
+        alias indices = Self.to_indices(t1_shape, attributes)
+        alias strides = t1_shape.strides()
+        alias total_length = len(product(indices))
+
+        var res_grad = Tensor[dtype](t1_shape)
+
+        @parameter
+        fn vec_index[nelts: Int](i: Int):
+            var offset = Self.map_indices[nelts, strides, indices](i)
+
+            # res_grad.data().scatter(
+            #     offset,
+            #     res_grad.data().gather(offset) + ug.load[nelts](i),
+            # )
+
+            # NOTE: A scatter with reduce-SUM semantics is required here.
+            # When offset = [0, 2, 4, 0] and ug = [1, 1, 1, 1], a standard
+            # scatter overwrites values at overlapping indices instead of
+            # accumulating them; index 0 should receive res_grad[0] += 1 + 1.
+            # cfr. https://github.com/ml-explore/mlx/blob/main/mlx/backend/common/indexing.cpp#L256-L258
+            # cfr. https://github.com/modularml/mojo/blob/main/stdlib/src/sys/intrinsics.mojo#L903
+
+            # Workaround: accumulate lane by lane.
+            var u = ug.load[nelts](i)
+            for j in range(nelts):
+                res_grad[int(offset[j])] += u[j]
+
+        vectorize[vec_index, nelts](total_length)
+
+        return res_grad^
\ No newline at end of file
diff --git a/basalt/autograd/ops/ops.mojo b/basalt/autograd/ops/ops.mojo
index 7198270..c737821 100644
--- a/basalt/autograd/ops/ops.mojo
+++ b/basalt/autograd/ops/ops.mojo
@@ -15,7 +15,7 @@ from .basics import (
     TRANSPOSE,
     FMA,
 )
-from .mlops import SIGMOID, RELU, TANH, CLIP, SQUEEZE, UNSQUEEZE, SLICE
+from .mlops import SIGMOID, RELU, TANH, CLIP, SQUEEZE, UNSQUEEZE, SLICE, INDEX
 from .dynamics import CONCAT, SPLIT
 from .conv import CONV2D
 from .pool import MAXPOOL2D
@@ -61,6 +61,7 @@ struct OP(Stringable):
     alias CONCAT = OP(23, "CONCAT", dynamic=True)
     alias SPLIT = OP(24, "SPLIT", dynamic=True)
     alias SLICE = OP(25, "SLICE")
+    alias INDEX = OP(26, "INDEX")
 
     var id: UInt8
     var name: Bytes[16]
@@ -135,6 +136,8 @@ fn static_result_shape(
         return UNSQUEEZE.result_shape(t1_shape, attributes)
     elif op == OP.SLICE:
         return SLICE.result_shape(t1_shape, attributes)
+    elif op == OP.INDEX:
+        return INDEX.result_shape(t1_shape, attributes)
     else:
         print("[ERROR] Operator not found.")
         return TensorShape(-1)
@@ -249,6 +252,8 @@ fn forward_op[
         UNSQUEEZE.forward[t1_shape, attributes](res, t1)
     elif op == OP.SLICE:
         SLICE.forward[t1_shape, attributes](res, t1)
+    elif op == OP.INDEX:
+        INDEX.forward[t1_shape, attributes](res, t1)
     else:
         print("[ERROR] Operator not found.")
 
@@ -361,6 +366,8 @@ fn backward_op[
         res_grad = UNSQUEEZE.backward[ug_shape, t1_shape](ug, t1)
     elif op == OP.SLICE:
         res_grad = SLICE.backward[ug_shape, t1_shape, attributes](ug, t1)
+    elif op == OP.INDEX:
+        res_grad = INDEX.backward[ug_shape, t1_shape, attributes](ug, t1)
     else:
         print("[ERROR] Operator not found.")
         res_grad = Tensor[dtype](-1)
diff --git a/basalt/utils/itertools.mojo b/basalt/utils/itertools.mojo
new file mode 100644
index 0000000..fd7a6ce
--- /dev/null
+++ b/basalt/utils/itertools.mojo
@@ -0,0 +1,49 @@
+
+@value
+struct _ProductIterator(Sized):
+    var lists: List[List[Int]]
+    var _current: Int
+    var _iters: Int
+
+    @always_inline("nodebug")
+    fn __init__(inout self, lists: List[List[Int]]):
+        self.lists = lists
+        self._current = 0
+
+        self._iters = 1
+        for lst in self.lists:
+            self._iters *= len(lst[])
+
+    @always_inline("nodebug")
+    fn __len__(self) -> Int:
+        return self._iters
+
+    @always_inline("nodebug")
+    fn __iter__(self) -> Self:
+        return self
+
+    @always_inline("nodebug")
+    fn __next__(inout self) -> List[Int]:
+        self._current += 1
+        self._iters -= 1
+        return self._get_combination(self._current - 1)
+
+    @always_inline("nodebug")
+    fn _get_combination(self, current: Int) -> List[Int]:
+        var combination = List[Int]()
+        var count = current
+        for i in reversed(range(len(self.lists))):
+            var index = count % len(self.lists[i])
+            combination.append(self.lists[i][index])
+            count //= len(self.lists[i])
+        combination._reverse()
+        return combination ^
+
+    @always_inline("nodebug")
+    fn __getitem__(self, index: Int) -> List[Int]:
+        return self._get_combination(index)
+
+
+@always_inline("nodebug")
+fn product(lists: List[List[Int]]) -> _ProductIterator:
+    return _ProductIterator(lists)
\ No newline at end of file
diff --git a/tests/mojo/test_mlops.mojo b/tests/mojo/test_mlops.mojo
index 2ba723e..4085d48 100644
--- a/tests/mojo/test_mlops.mojo
+++ b/tests/mojo/test_mlops.mojo
@@ -620,6 +620,59 @@ fn test_backward_SLICE_multiple_axes() raises:
     ](t1, ug, expected_ug)
 
 
+fn test_INDEX() raises:
+    alias t1_shape = TensorShape(2, 3, 5)
+    var t = Tensor[dtype](t1_shape)
+    for i in range(t.num_elements()):
+        t[i] = i
+
+    # t[:, [0, 0], 0:5:2]
+    # TODO: needs a proper list attribute; indices can currently only be passed as a TensorShape of up to MAX_RANK entries.
+    alias attr_1 = Attribute("dim_1i", TensorShape(0, 0))
+    alias attr_2 = Attribute("dim_2s", TensorShape(0, 5, 2))
+
+    var expected = Tensor[dtype](2, 2, 3)
+    for i in range(2):
+        for j in range(2):
+            for k in range(3):
+                expected[i * 2 * 3 + j * 3 + k] = i * 3 * 5 + k * 2
+
+    test_unary_op[
+        OP.INDEX, t1_shape, AttributeVector(
+            attr_1,
+            attr_2,
+        )
+    ](t, expected)
+
+
+fn test_INDEX_backward() raises:
+    alias t1_shape = TensorShape(2, 3, 5)
+    var t = Tensor[dtype](t1_shape)
+    for i in range(t.num_elements()):
+        t[i] = i
+
+    alias attr_1 = Attribute("dim_1i", TensorShape(0, 0))
+    alias attr_2 = Attribute("dim_2s", TensorShape(0, 5, 2))
+
+    alias ug_shape = TensorShape(2, 2, 3)
+    var ug = Tensor[dtype](ug_shape)
+    fill(ug, 1.0)
+
+    var expected = Tensor[dtype](t1_shape)
+    for i in range(2):
+        for j in range(2):
+            for k in range(3):
+                # NOTE: `+=` because the selected indices [0, 0] overlap.
+                expected[i * 3 * 5 + k * 2] += 1.0
+
+    test_unary_op_backward[
+        OP.INDEX, t1_shape, ug_shape, AttributeVector(
+            attr_1,
+            attr_2,
+        )
+    ](t, ug, expected)
+
+
 fn main():
     try:
         test_SIGMOID()
@@ -632,6 +685,7 @@ fn main():
         test_SLICE_step()
         test_SLICE_neg()
         test_SLICE_multiple_axes()
+        test_INDEX()
     except e:
         print("[ERROR] Error in forward mlops")
        print(e)
@@ -646,6 +700,7 @@ fn main():
         test_backward_UNSQUEEZE()
         test_backward_SLICE()
         test_backward_SLICE_multiple_axes()
+        test_INDEX_backward()
     except e:
        print("[ERROR] Error in backward mlops")
        print(e)
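
Note on the indexing scheme (a reviewer sketch, not part of the patch): INDEX expands the per-dimension attributes into index lists, takes their cartesian product with the product helper added in basalt/utils/itertools.mojo, and dots each combination with the input strides to get flat offsets, which is what map_indices does per SIMD lane. The names dims, strides, and flat below are illustrative only; the values mirror the test case t[:, [0, 0], 0:5:2] on a (2, 3, 5) tensor.

    from basalt.utils.itertools import product

    fn main():
        # Per-dimension index lists: all of dim 0, [0, 0] for dim 1,
        # and the slice 0:5:2 (-> [0, 2, 4]) for dim 2.
        var dims = List[List[Int]]()
        dims.append(List[Int](0, 1))
        dims.append(List[Int](0, 0))
        dims.append(List[Int](0, 2, 4))

        # Row-major strides of a (2, 3, 5) tensor.
        var strides = List[Int](15, 5, 1)

        # Each combination maps to a flat offset into the input buffer.
        for comb in product(dims):
            var flat = 0
            for d in range(len(comb)):
                flat += comb[d] * strides[d]
            print(flat)

This prints 0, 2, 4, 0, 2, 4, 15, 17, 19, 15, 17, 19. The repeated offsets, caused by the duplicated index 0 in dim 1, are exactly why backward needs sum-scatter semantics rather than a plain overwriting scatter.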