Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
252 changes: 252 additions & 0 deletions numojo/core/ndarray.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -3458,6 +3458,104 @@ struct NDArray[dtype: DType = DType.float64](
)
self._buf.ptr.store(_get_offset(indices, self.strides), item)

fn iter_by_axis[
forward: Bool = True
](self, axis: Int) raises -> _NDAxisIter[__origin_of(self), dtype, forward]:
"""
Returns an iterator yielding 1-d array by axis.

Parameters:
forward: If True, iterate from the beginning to the end.
If False, iterate from the end to the beginning.

Args:
axis: Axis by which the iteration is performed.

Returns:
An iterator yielding 1-d array by axis.

Example:
```mojo
from numojo.prelude import *
var a = nm.arange[i8](24).reshape(Shape(2, 3, 4))
print(a)
for i in a.iter_by_axis(axis=0):
print(String(i))
```

This prints:

```console
[[[ 0 1 2 3]
[ 4 5 6 7]
[ 8 9 10 11]]
[[12 13 14 15]
[16 17 18 19]
[20 21 22 23]]]
3D-array Shape(2,3,4) Strides(12,4,1) DType: i8 C-cont: True F-cont: False own data: True
[ 0 12]
[ 1 13]
[ 2 14]
[ 3 15]
[ 4 16]
[ 5 17]
[ 6 18]
[ 7 19]
[ 8 20]
[ 9 21]
[10 22]
[11 23]
```

Another example:

```mojo
from numojo.prelude import *
var a = nm.arange[i8](24).reshape(Shape(2, 3, 4))
print(a)
for i in a.iter_by_axis(axis=2):
print(String(i))
```

This prints:

```console
[[[ 0 1 2 3]
[ 4 5 6 7]
[ 8 9 10 11]]
[[12 13 14 15]
[16 17 18 19]
[20 21 22 23]]]
3D-array Shape(2,3,4) Strides(12,4,1) DType: i8 C-cont: True F-cont: False own data: True
[0 1 2 3]
[4 5 6 7]
[ 8 9 10 11]
[12 13 14 15]
[16 17 18 19]
[20 21 22 23]
```.
"""

var normalized_axis: Int = axis
if normalized_axis < 0:
normalized_axis += self.ndim
if (normalized_axis >= self.ndim) or (normalized_axis < 0):
raise Error(
String(
"\nError in `NDArray.iter_by_axis()`: "
"Axis ({}) is not in valid range [{}, {})."
).format(axis, -self.ndim, self.ndim)
)

return _NDAxisIter[__origin_of(self), dtype, forward](
ptr=self._buf.ptr,
axis=normalized_axis,
size=self.size,
ndim=self.ndim,
shape=self.shape,
strides=self.strides,
)

fn max(self, axis: Int = 0) raises -> Self:
"""
Max on axis.
Expand Down Expand Up @@ -3995,10 +4093,164 @@ struct _NDArrayIter[
return self.index


@value
struct _NDAxisIter[
is_mutable: Bool, //,
origin: Origin[is_mutable],
dtype: DType,
forward: Bool = True,
]():
# TODO:
# 1. Use `length` (`index`) instead of `size` (`offset`) for the
# length (counter) of the iterator.
# 2. Return a view instead of copy if possible (when Bufferable is supported).
# 3. Add an argument in `__init__()` to specify the starting offset or index.
"""
An iterator yielding 1-d array by axis.
The yielded array is garanteed to be contiguous on memory,
and it is a view of the original array if possible.

It can be used when a function reduces the dimension of the array by axis.

Parameters:
is_mutable: Whether the iterator is mutable.
origin: The lifetime of the underlying NDArray data.
dtype: The data type of the item.
forward: The iteration direction. `False` is backwards.

Example:
```
[[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]],
[[12, 13, 14, 15],
[16, 17, 18, 19],
[20, 21, 22, 23]]]
```
The above array is of shape (2,3,3). Itering by `axis=0` returns:
```
[0, 12], [1, 13], [2, 14], [3, 15],
[4, 16], [5, 17], [6, 18], [7, 19],
[8, 20], [9, 21], [10, 22], [11, 23]
```
Itering by `axis=1` returns:
```
[0, 4, 8], [1, 5, 9], [2, 6, 10], [3, 7, 11],
[12, 16, 20], [13, 17, 21], [14, 18, 22], [15, 19, 23]
```
"""

var ptr: UnsafePointer[Scalar[dtype]]
var axis: Int
var size: Int
var ndim: Int
var shape: NDArrayShape
var strides: NDArrayStrides
"""Strides of array or view. It is not necessarily compatible with shape."""
var strides_by_axis: NDArrayStrides
"""Strides by axis according to shape of view."""
var offset: Int
var size_of_res: Int
"""Size of the result 1-d array."""

fn __init__(
out self,
ptr: UnsafePointer[Scalar[dtype]],
axis: Int,
size: Int,
ndim: Int,
shape: NDArrayShape,
strides: NDArrayStrides,
) raises:
"""
Initialize the iterator.

Args:
ptr: Pointer to the data buffer.
axis: Axis.
size: Size of the axis.
ndim: Number of dimensions.
shape: Shape of the array.
strides: Strides of array or view. It is not necessarily compatible with shape.
"""
if axis < 0 or axis >= ndim:
raise Error("Axis must be in the range of [0, ndim).")

self.size_of_res = shape[axis]
self.offset = 0 if forward else size - self.size_of_res
self.ptr = ptr
self.axis = axis
self.size = size
self.ndim = ndim
self.shape = shape
self.strides = strides
self.strides_by_axis = NDArrayStrides(ndim=self.ndim, initialized=False)
var temp = 1
(self.strides_by_axis._buf + axis).init_pointee_copy(temp)
temp *= shape[axis]
for i in range(self.ndim - 1, -1, -1):
if i != axis:
(self.strides_by_axis._buf + i).init_pointee_copy(temp)
temp *= shape[i]

fn __has_next__(self) -> Bool:
@parameter
if forward:
return self.offset < self.size
else:
return self.offset > 0 - self.size_of_res

fn __iter__(self) -> Self:
return self

fn __len__(self) -> Int:
@parameter
if forward:
return (self.size - self.offset) // self.size_of_res
else:
return self.offset // self.size_of_res + 1

fn __next__(mut self) raises -> NDArray[dtype]:
var res = NDArray[dtype](Shape(self.size_of_res))
var current_offset = self.offset

@parameter
if forward:
self.offset += self.size_of_res
else:
self.offset -= self.size_of_res

var remainder = current_offset
var item = Item(ndim=self.ndim, initialized=True)

for i in range(self.axis):
item[i], remainder = divmod(remainder, self.strides_by_axis[i])

for i in range(self.axis + 1, self.ndim):
item[i], remainder = divmod(remainder, self.strides_by_axis[i])

item[self.axis], remainder = divmod(
remainder, self.strides_by_axis[self.axis]
)

for j in range(self.size_of_res):
(res._buf.ptr + j).init_pointee_copy(
self.ptr[_get_offset(item, self.strides)]
)
item[self.axis] += 1

return res^


@value
struct _NDIter[
is_mutable: Bool, //, origin: Origin[is_mutable], dtype: DType
]():
# TODO: Combine into `_NDAxisIter` with `axis=ndim-1`.
"""
An iterator yielding the array elements according to the order.
"""

var ptr: UnsafePointer[Scalar[dtype]]
var length: Int
var ndim: Int
Expand Down
17 changes: 15 additions & 2 deletions tests/core/test_array_methods.mojo
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import numojo as nm
from numojo import *
from numojo.prelude import *
from testing.testing import assert_true, assert_almost_equal, assert_equal
from utils_for_test import check, check_is_close

Expand Down Expand Up @@ -58,3 +57,17 @@ def test_constructors():
arr5.shape[2] == 5,
"NDArray constructor with NDArrayShape: shape element 2",
)


def test_iterator():
var a = nm.arange[i8](24).reshape(Shape(2, 3, 4))
var a_iter = a.iter_by_axis[forward=False](axis=0)
var b = a_iter.__next__() == nm.array[i8]("[11, 23]")
assert_true(
b.item(0) == True,
"`_NDAxisIter` breaks",
)
assert_true(
b.item(1) == True,
"`_NDAxisIter` breaks",
)