diff --git a/numojo/core/ndarray.mojo b/numojo/core/ndarray.mojo index 1adb505a..b4c73d10 100644 --- a/numojo/core/ndarray.mojo +++ b/numojo/core/ndarray.mojo @@ -3458,6 +3458,104 @@ struct NDArray[dtype: DType = DType.float64]( ) self._buf.ptr.store(_get_offset(indices, self.strides), item) + fn iter_by_axis[ + forward: Bool = True + ](self, axis: Int) raises -> _NDAxisIter[__origin_of(self), dtype, forward]: + """ + Returns an iterator yielding 1-d array by axis. + + Parameters: + forward: If True, iterate from the beginning to the end. + If False, iterate from the end to the beginning. + + Args: + axis: Axis by which the iteration is performed. + + Returns: + An iterator yielding 1-d array by axis. + + Example: + ```mojo + from numojo.prelude import * + var a = nm.arange[i8](24).reshape(Shape(2, 3, 4)) + print(a) + for i in a.iter_by_axis(axis=0): + print(String(i)) + ``` + + This prints: + + ```console + [[[ 0 1 2 3] + [ 4 5 6 7] + [ 8 9 10 11]] + [[12 13 14 15] + [16 17 18 19] + [20 21 22 23]]] + 3D-array Shape(2,3,4) Strides(12,4,1) DType: i8 C-cont: True F-cont: False own data: True + [ 0 12] + [ 1 13] + [ 2 14] + [ 3 15] + [ 4 16] + [ 5 17] + [ 6 18] + [ 7 19] + [ 8 20] + [ 9 21] + [10 22] + [11 23] + ``` + + Another example: + + ```mojo + from numojo.prelude import * + var a = nm.arange[i8](24).reshape(Shape(2, 3, 4)) + print(a) + for i in a.iter_by_axis(axis=2): + print(String(i)) + ``` + + This prints: + + ```console + [[[ 0 1 2 3] + [ 4 5 6 7] + [ 8 9 10 11]] + [[12 13 14 15] + [16 17 18 19] + [20 21 22 23]]] + 3D-array Shape(2,3,4) Strides(12,4,1) DType: i8 C-cont: True F-cont: False own data: True + [0 1 2 3] + [4 5 6 7] + [ 8 9 10 11] + [12 13 14 15] + [16 17 18 19] + [20 21 22 23] + ```. + """ + + var normalized_axis: Int = axis + if normalized_axis < 0: + normalized_axis += self.ndim + if (normalized_axis >= self.ndim) or (normalized_axis < 0): + raise Error( + String( + "\nError in `NDArray.iter_by_axis()`: " + "Axis ({}) is not in valid range [{}, {})." + ).format(axis, -self.ndim, self.ndim) + ) + + return _NDAxisIter[__origin_of(self), dtype, forward]( + ptr=self._buf.ptr, + axis=normalized_axis, + size=self.size, + ndim=self.ndim, + shape=self.shape, + strides=self.strides, + ) + fn max(self, axis: Int = 0) raises -> Self: """ Max on axis. @@ -3995,10 +4093,164 @@ struct _NDArrayIter[ return self.index +@value +struct _NDAxisIter[ + is_mutable: Bool, //, + origin: Origin[is_mutable], + dtype: DType, + forward: Bool = True, +](): + # TODO: + # 1. Use `length` (`index`) instead of `size` (`offset`) for the + # length (counter) of the iterator. + # 2. Return a view instead of copy if possible (when Bufferable is supported). + # 3. Add an argument in `__init__()` to specify the starting offset or index. + """ + An iterator yielding 1-d array by axis. + The yielded array is garanteed to be contiguous on memory, + and it is a view of the original array if possible. + + It can be used when a function reduces the dimension of the array by axis. + + Parameters: + is_mutable: Whether the iterator is mutable. + origin: The lifetime of the underlying NDArray data. + dtype: The data type of the item. + forward: The iteration direction. `False` is backwards. + + Example: + ``` + [[[ 0, 1, 2, 3], + [ 4, 5, 6, 7], + [ 8, 9, 10, 11]], + [[12, 13, 14, 15], + [16, 17, 18, 19], + [20, 21, 22, 23]]] + ``` + The above array is of shape (2,3,3). Itering by `axis=0` returns: + ``` + [0, 12], [1, 13], [2, 14], [3, 15], + [4, 16], [5, 17], [6, 18], [7, 19], + [8, 20], [9, 21], [10, 22], [11, 23] + ``` + Itering by `axis=1` returns: + ``` + [0, 4, 8], [1, 5, 9], [2, 6, 10], [3, 7, 11], + [12, 16, 20], [13, 17, 21], [14, 18, 22], [15, 19, 23] + ``` + """ + + var ptr: UnsafePointer[Scalar[dtype]] + var axis: Int + var size: Int + var ndim: Int + var shape: NDArrayShape + var strides: NDArrayStrides + """Strides of array or view. It is not necessarily compatible with shape.""" + var strides_by_axis: NDArrayStrides + """Strides by axis according to shape of view.""" + var offset: Int + var size_of_res: Int + """Size of the result 1-d array.""" + + fn __init__( + out self, + ptr: UnsafePointer[Scalar[dtype]], + axis: Int, + size: Int, + ndim: Int, + shape: NDArrayShape, + strides: NDArrayStrides, + ) raises: + """ + Initialize the iterator. + + Args: + ptr: Pointer to the data buffer. + axis: Axis. + size: Size of the axis. + ndim: Number of dimensions. + shape: Shape of the array. + strides: Strides of array or view. It is not necessarily compatible with shape. + """ + if axis < 0 or axis >= ndim: + raise Error("Axis must be in the range of [0, ndim).") + + self.size_of_res = shape[axis] + self.offset = 0 if forward else size - self.size_of_res + self.ptr = ptr + self.axis = axis + self.size = size + self.ndim = ndim + self.shape = shape + self.strides = strides + self.strides_by_axis = NDArrayStrides(ndim=self.ndim, initialized=False) + var temp = 1 + (self.strides_by_axis._buf + axis).init_pointee_copy(temp) + temp *= shape[axis] + for i in range(self.ndim - 1, -1, -1): + if i != axis: + (self.strides_by_axis._buf + i).init_pointee_copy(temp) + temp *= shape[i] + + fn __has_next__(self) -> Bool: + @parameter + if forward: + return self.offset < self.size + else: + return self.offset > 0 - self.size_of_res + + fn __iter__(self) -> Self: + return self + + fn __len__(self) -> Int: + @parameter + if forward: + return (self.size - self.offset) // self.size_of_res + else: + return self.offset // self.size_of_res + 1 + + fn __next__(mut self) raises -> NDArray[dtype]: + var res = NDArray[dtype](Shape(self.size_of_res)) + var current_offset = self.offset + + @parameter + if forward: + self.offset += self.size_of_res + else: + self.offset -= self.size_of_res + + var remainder = current_offset + var item = Item(ndim=self.ndim, initialized=True) + + for i in range(self.axis): + item[i], remainder = divmod(remainder, self.strides_by_axis[i]) + + for i in range(self.axis + 1, self.ndim): + item[i], remainder = divmod(remainder, self.strides_by_axis[i]) + + item[self.axis], remainder = divmod( + remainder, self.strides_by_axis[self.axis] + ) + + for j in range(self.size_of_res): + (res._buf.ptr + j).init_pointee_copy( + self.ptr[_get_offset(item, self.strides)] + ) + item[self.axis] += 1 + + return res^ + + @value struct _NDIter[ is_mutable: Bool, //, origin: Origin[is_mutable], dtype: DType ](): + # TODO: Combine into `_NDAxisIter` with `axis=ndim-1`. + """ + An iterator yielding the array elements according to the order. + """ + var ptr: UnsafePointer[Scalar[dtype]] var length: Int var ndim: Int diff --git a/tests/core/test_array_methods.mojo b/tests/core/test_array_methods.mojo index 6bc42307..3848617a 100644 --- a/tests/core/test_array_methods.mojo +++ b/tests/core/test_array_methods.mojo @@ -1,5 +1,4 @@ -import numojo as nm -from numojo import * +from numojo.prelude import * from testing.testing import assert_true, assert_almost_equal, assert_equal from utils_for_test import check, check_is_close @@ -58,3 +57,17 @@ def test_constructors(): arr5.shape[2] == 5, "NDArray constructor with NDArrayShape: shape element 2", ) + + +def test_iterator(): + var a = nm.arange[i8](24).reshape(Shape(2, 3, 4)) + var a_iter = a.iter_by_axis[forward=False](axis=0) + var b = a_iter.__next__() == nm.array[i8]("[11, 23]") + assert_true( + b.item(0) == True, + "`_NDAxisIter` breaks", + ) + assert_true( + b.item(1) == True, + "`_NDAxisIter` breaks", + )