Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 12 additions & 11 deletions mojoproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,20 +24,21 @@ p = "clear && magic run package"
format = "magic run mojo format ./"

# test whether tests pass on the built package
test = "magic run package && magic run mojo test tests -I ./tests/"
test = "magic run package && magic run mojo test tests -I tests/ && rm tests/numojo.mojopkg"
t = "clear && magic run test"

# run individual tests to avoid overheat
test_core = "magic run mojo test tests/core -I ./ -I ./tests/"
test_creation = "magic run mojo test tests/routines/test_creation.mojo -I ./ -I ./tests/"
test_functional = "magic run mojo test tests/routines/test_functional.mojo -I ./ -I ./tests/"
test_indexing = "magic run mojo test tests/routines/test_indexing.mojo -I ./ -I ./tests/"
test_linalg = "magic run mojo test tests/routines/test_linalg.mojo -I ./ -I ./tests/"
test_manipulation = "magic run mojo test tests/routines/test_manipulation.mojo -I ./ -I ./tests/"
test_math = "magic run mojo test tests/routines/test_math.mojo -I ./ -I ./tests/"
test_random = "magic run mojo test tests/routines/test_random.mojo -I ./ -I ./tests/"
test_statistics = "magic run mojo test tests/routines/test_statistics.mojo -I ./ -I ./tests/"
test_sorting = "magic run mojo test tests/routines/test_sorting.mojo -I ./ -I ./tests/"
test_core = "magic run package && magic run mojo test tests/core -I tests/ && rm tests/numojo.mojopkg"
test_creation = "magic run package && magic run mojo test tests/routines/test_creation.mojo -I tests/ && rm tests/numojo.mojopkg"
test_functional = "magic run package && magic run mojo test tests/routines/test_functional.mojo -I tests/ && rm tests/numojo.mojopkg"
test_indexing = "magic run package && magic run mojo test tests/routines/test_indexing.mojo -I tests/ && rm tests/numojo.mojopkg"
test_linalg = "magic run package && magic run mojo test tests/routines/test_linalg.mojo -I tests/ && rm tests/numojo.mojopkg"
test_manipulation = "magic run package && magic run mojo test tests/routines/test_manipulation.mojo -I tests/ && rm tests/numojo.mojopkg"
test_math = "magic run package && magic run mojo test tests/routines/test_math.mojo -I tests/ && rm tests/numojo.mojopkg"
test_random = "magic run package && magic run mojo test tests/routines/test_random.mojo -I tests/ && rm tests/numojo.mojopkg"
test_statistics = "magic run package && magic run mojo test tests/routines/test_statistics.mojo -I tests/ && rm tests/numojo.mojopkg"
test_sorting = "magic run package && magic run mojo test tests/routines/test_sorting.mojo -I tests/ && rm tests/numojo.mojopkg"
test_searching = "magic run package && magic run mojo test tests/routines/test_searching.mojo -I tests/ && rm tests/numojo.mojopkg"

# run all final checks before a commit
final = "magic run format && magic run test"
Expand Down
51 changes: 23 additions & 28 deletions numojo/core/ndarray.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ from numojo.routines.io.formatting import (
import numojo.routines.logic.comparison as comparison
import numojo.routines.math.arithmetic as arithmetic
import numojo.routines.math.rounding as rounding
import numojo.routines.searching as searching


struct NDArray[dtype: DType = DType.float64](
Expand Down Expand Up @@ -3630,42 +3631,36 @@ struct NDArray[dtype: DType = DType.float64](
vectorize[vectorized_any, self.width](self.size)
return result

fn argmax(self) raises -> Int:
fn argmax(self) raises -> Scalar[DType.index]:
"""Returns the indices of the maximum values along an axis.
When no axis is specified, the array is flattened.
See `numojo.argmax()` for more details.
"""
Get location in pointer of max value.
return searching.argmax(self)

Returns:
Index of the maximum value.
"""
var result: Int = 0
var max_val: SIMD[dtype, 1] = self.load[width=1](0)
for i in range(1, self.size):
var temp: SIMD[dtype, 1] = self.load[width=1](i)
if temp > max_val:
max_val = temp
result = i
return result
fn argmax(self, axis: Int) raises -> NDArray[DType.index]:
"""Returns the indices of the maximum values along an axis.
See `numojo.argmax()` for more details.
"""
return searching.argmax(self, axis=axis)

fn argmin(self) raises -> Int:
fn argmin(self) raises -> Scalar[DType.index]:
"""Returns the indices of the minimum values along an axis.
When no axis is specified, the array is flattened.
See `numojo.argmin()` for more details.
"""
Get location in pointer of min value.
return searching.argmin(self)

Returns:
Index of the minimum value.
"""
var result: Int = 0
var min_val: SIMD[dtype, 1] = self.load[width=1](0)
for i in range(1, self.size):
var temp: SIMD[dtype, 1] = self.load[width=1](i)
if temp < min_val:
min_val = temp
result = i
return result
fn argmin(self, axis: Int) raises -> NDArray[DType.index]:
"""Returns the indices of the minimum values along an axis.
See `numojo.argmin()` for more details.
"""
return searching.argmin(self, axis=axis)

fn argsort(self) raises -> NDArray[DType.index]:
"""
Sort the NDArray and return the sorted indices.
See `numojo.argsort()`.
See `numojo.argsort()` for more details.

Returns:
The indices of the sorted NDArray.
Expand All @@ -3676,7 +3671,7 @@ struct NDArray[dtype: DType = DType.float64](
fn argsort(self, axis: Int) raises -> NDArray[DType.index]:
"""
Sort the NDArray and return the sorted indices.
See `numojo.argsort()`.
See `numojo.argsort()` for more details.

Returns:
The indices of the sorted NDArray.
Expand Down
56 changes: 56 additions & 0 deletions numojo/routines/functional.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ from numojo.core.ndstrides import NDArrayStrides
# `a` is a 1-d array slice of the original array along given axis.
# ===----------------------------------------------------------------------=== #

# The following overloads of `apply_along_axis` are for the case when the
# dimension of the input array is reduced.


fn apply_along_axis[
dtype: DType,
Expand Down Expand Up @@ -76,6 +79,55 @@ fn apply_along_axis[
return res^


fn apply_along_axis[
dtype: DType,
func1d: fn[dtype_func: DType] (NDArray[dtype_func]) raises -> Scalar[
DType.index
],
](a: NDArray[dtype], axis: Int) raises -> NDArray[DType.index]:
"""
Applies a function to a NDArray by axis and reduce that dimension.
The returned data type is DType.index.
When the array is 1-d, the returned array will be a 0-d array.

Parameters:
dtype: The data type of the input NDArray elements.
func1d: The function to apply to the NDArray.

Args:
a: The NDArray to apply the function to.
axis: The axis to apply the function to.

Returns:
The NDArray with the function applied to the input NDArray by axis.
"""

# The iterator along the axis
var iterator = a.iter_along_axis(axis=axis)
# The final output array will have 1 less dimension than the input array
var res: NDArray[DType.index]

if a.ndim == 1:
res = numojo.creation._0darray[DType.index](0)
(res._buf.ptr).init_pointee_copy(func1d[dtype](a))

else:
res = NDArray[DType.index](a.shape._pop(axis=axis))

@parameter
fn parallelized_func(i: Int):
try:
(res._buf.ptr + i).init_pointee_copy(
func1d[dtype](iterator.ith(i))
)
except e:
print("Error in parallelized_func", e)

parallelize[parallelized_func](a.size // a.shape[axis])

return res^


fn apply_along_axis[
dtype: DType, //,
returned_dtype: DType,
Expand Down Expand Up @@ -131,6 +183,10 @@ fn apply_along_axis[
return res^


# The following overloads of `apply_along_axis` are for the case when the
# dimension of the input array is not reduced.


fn apply_along_axis[
dtype: DType, //,
func1d: fn[dtype_func: DType] (NDArray[dtype_func]) raises -> NDArray[
Expand Down
4 changes: 2 additions & 2 deletions numojo/routines/math/extrema.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ fn max[dtype: DType](a: NDArray[dtype], axis: Int) raises -> NDArray[dtype]:
normalized_axis += a.ndim
if (normalized_axis < 0) or (normalized_axis >= a.ndim):
raise Error(
String("Error in `mean`: Axis {} not in bound [-{}, {})").format(
String("Error in `max`: Axis {} not in bound [-{}, {})").format(
axis, a.ndim, a.ndim
)
)
Expand Down Expand Up @@ -243,7 +243,7 @@ fn min[dtype: DType](a: NDArray[dtype], axis: Int) raises -> NDArray[dtype]:
normalized_axis += a.ndim
if (normalized_axis < 0) or (normalized_axis >= a.ndim):
raise Error(
String("Error in `mean`: Axis {} not in bound [-{}, {})").format(
String("Error in `min`: Axis {} not in bound [-{}, {})").format(
axis, a.ndim, a.ndim
)
)
Expand Down
Loading