diff --git a/mojoproject.toml b/mojoproject.toml
index ce4f6ff9..996322da 100644
--- a/mojoproject.toml
+++ b/mojoproject.toml
@@ -24,20 +24,21 @@ p = "clear && magic run package"
 format = "magic run mojo format ./"
 
 # test whether tests pass on the built package
-test = "magic run package && magic run mojo test tests -I ./tests/"
+test = "magic run package && magic run mojo test tests -I tests/ && rm tests/numojo.mojopkg"
 t = "clear && magic run test"
 
 # run individual tests to avoid overheat
-test_core = "magic run mojo test tests/core -I ./ -I ./tests/"
-test_creation = "magic run mojo test tests/routines/test_creation.mojo -I ./ -I ./tests/"
-test_functional = "magic run mojo test tests/routines/test_functional.mojo -I ./ -I ./tests/"
-test_indexing = "magic run mojo test tests/routines/test_indexing.mojo -I ./ -I ./tests/"
-test_linalg = "magic run mojo test tests/routines/test_linalg.mojo -I ./ -I ./tests/"
-test_manipulation = "magic run mojo test tests/routines/test_manipulation.mojo -I ./ -I ./tests/"
-test_math = "magic run mojo test tests/routines/test_math.mojo -I ./ -I ./tests/"
-test_random = "magic run mojo test tests/routines/test_random.mojo -I ./ -I ./tests/"
-test_statistics = "magic run mojo test tests/routines/test_statistics.mojo -I ./ -I ./tests/"
-test_sorting = "magic run mojo test tests/routines/test_sorting.mojo -I ./ -I ./tests/"
+test_core = "magic run package && magic run mojo test tests/core -I tests/ && rm tests/numojo.mojopkg"
+test_creation = "magic run package && magic run mojo test tests/routines/test_creation.mojo -I tests/ && rm tests/numojo.mojopkg"
+test_functional = "magic run package && magic run mojo test tests/routines/test_functional.mojo -I tests/ && rm tests/numojo.mojopkg"
+test_indexing = "magic run package && magic run mojo test tests/routines/test_indexing.mojo -I tests/ && rm tests/numojo.mojopkg"
+test_linalg = "magic run package && magic run mojo test tests/routines/test_linalg.mojo -I tests/ && rm tests/numojo.mojopkg"
+test_manipulation = "magic run package && magic run mojo test tests/routines/test_manipulation.mojo -I tests/ && rm tests/numojo.mojopkg"
+test_math = "magic run package && magic run mojo test tests/routines/test_math.mojo -I tests/ && rm tests/numojo.mojopkg"
+test_random = "magic run package && magic run mojo test tests/routines/test_random.mojo -I tests/ && rm tests/numojo.mojopkg"
+test_statistics = "magic run package && magic run mojo test tests/routines/test_statistics.mojo -I tests/ && rm tests/numojo.mojopkg"
+test_sorting = "magic run package && magic run mojo test tests/routines/test_sorting.mojo -I tests/ && rm tests/numojo.mojopkg"
+test_searching = "magic run package && magic run mojo test tests/routines/test_searching.mojo -I tests/ && rm tests/numojo.mojopkg"
 
 # run all final checks before a commit
 final = "magic run format && magic run test"
diff --git a/numojo/core/ndarray.mojo b/numojo/core/ndarray.mojo
index 890919d7..edf851b4 100644
--- a/numojo/core/ndarray.mojo
+++ b/numojo/core/ndarray.mojo
@@ -88,6 +88,7 @@ from numojo.routines.io.formatting import (
 import numojo.routines.logic.comparison as comparison
 import numojo.routines.math.arithmetic as arithmetic
 import numojo.routines.math.rounding as rounding
+import numojo.routines.searching as searching
 
 
 struct NDArray[dtype: DType = DType.float64](
@@ -3630,42 +3631,36 @@ struct NDArray[dtype: DType = DType.float64](
         vectorize[vectorized_any, self.width](self.size)
         return result
 
-    fn argmax(self) raises -> Int:
+    fn argmax(self) raises -> Scalar[DType.index]:
+        """Returns the index of the maximum value of the flattened array.
+        Thin wrapper that delegates to `searching.argmax(self)`.
+        See `numojo.argmax()` for more details.
         """
-        Get location in pointer of max value.
+        return searching.argmax(self)
 
-        Returns:
-            Index of the maximum value.
-        """
-        var result: Int = 0
-        var max_val: SIMD[dtype, 1] = self.load[width=1](0)
-        for i in range(1, self.size):
-            var temp: SIMD[dtype, 1] = self.load[width=1](i)
-            if temp > max_val:
-                max_val = temp
-                result = i
-        return result
+    fn argmax(self, axis: Int) raises -> NDArray[DType.index]:
+        """Returns the indices of the maximum values along the given axis.
+        See `numojo.argmax()` for more details.
+        """
+        return searching.argmax(self, axis=axis)
 
-    fn argmin(self) raises -> Int:
+    fn argmin(self) raises -> Scalar[DType.index]:
+        """Returns the index of the minimum value of the flattened array.
+        Thin wrapper that delegates to `searching.argmin(self)`.
+        See `numojo.argmin()` for more details.
         """
-        Get location in pointer of min value.
+        return searching.argmin(self)
 
-        Returns:
-            Index of the minimum value.
-        """
-        var result: Int = 0
-        var min_val: SIMD[dtype, 1] = self.load[width=1](0)
-        for i in range(1, self.size):
-            var temp: SIMD[dtype, 1] = self.load[width=1](i)
-            if temp < min_val:
-                min_val = temp
-                result = i
-        return result
+    fn argmin(self, axis: Int) raises -> NDArray[DType.index]:
+        """Returns the indices of the minimum values along the given axis.
+        See `numojo.argmin()` for more details.
+        """
+        return searching.argmin(self, axis=axis)
 
     fn argsort(self) raises -> NDArray[DType.index]:
         """
         Sort the NDArray and return the sorted indices.
-        See `numojo.argsort()`.
+        See `numojo.argsort()` for more details.
 
         Returns:
             The indices of the sorted NDArray.
@@ -3676,7 +3671,7 @@ struct NDArray[dtype: DType = DType.float64](
     fn argsort(self, axis: Int) raises -> NDArray[DType.index]:
         """
         Sort the NDArray and return the sorted indices.
-        See `numojo.argsort()`.
+        See `numojo.argsort()` for more details.
 
         Returns:
             The indices of the sorted NDArray.
diff --git a/numojo/routines/functional.mojo b/numojo/routines/functional.mojo
index 262d6872..20355f86 100644
--- a/numojo/routines/functional.mojo
+++ b/numojo/routines/functional.mojo
@@ -27,6 +27,9 @@ from numojo.core.ndstrides import NDArrayStrides
 # `a` is a 1-d array slice of the original array along given axis.
 # ===----------------------------------------------------------------------=== #
 
+# The following overloads of `apply_along_axis` are for the case when the
+# dimension of the input array is reduced.
+
 
 fn apply_along_axis[
     dtype: DType,
@@ -76,6 +79,55 @@ fn apply_along_axis[
     return res^
 
 
+fn apply_along_axis[
+    dtype: DType,
+    func1d: fn[dtype_func: DType] (NDArray[dtype_func]) raises -> Scalar[
+        DType.index
+    ],
+](a: NDArray[dtype], axis: Int) raises -> NDArray[DType.index]:
+    """
+    Applies a function to a NDArray by axis and reduce that dimension.
+    The returned data type is DType.index.
+    When the array is 1-d, the returned array will be a 0-d array.
+
+    Parameters:
+        dtype: The data type of the input NDArray elements.
+        func1d: The function to apply to the NDArray.
+
+    Args:
+        a: The NDArray to apply the function to.
+        axis: The axis to apply the function to.
+
+    Returns:
+        The NDArray with the function applied to the input NDArray by axis.
+    """
+
+    # The iterator along the axis
+    var iterator = a.iter_along_axis(axis=axis)
+    # The final output array will have 1 less dimension than the input array
+    var res: NDArray[DType.index]
+
+    if a.ndim == 1:
+        res = numojo.creation._0darray[DType.index](0)
+        (res._buf.ptr).init_pointee_copy(func1d[dtype](a))
+
+    else:
+        res = NDArray[DType.index](a.shape._pop(axis=axis))
+
+        @parameter
+        fn parallelized_func(i: Int):
+            try:
+                (res._buf.ptr + i).init_pointee_copy(
+                    func1d[dtype](iterator.ith(i))
+                )
+            except e:
+                print("Error in parallelized_func", e)
+
+        parallelize[parallelized_func](a.size // a.shape[axis])
+
+    return res^
+
+
 fn apply_along_axis[
     dtype: DType, //,
     returned_dtype: DType,
@@ -131,6 +183,10 @@ fn apply_along_axis[
     return res^
 
 
+# The following overloads of `apply_along_axis` are for the case when the
+# dimension of the input array is not reduced.
+
+
 fn apply_along_axis[
     dtype: DType, //,
     func1d: fn[dtype_func: DType] (NDArray[dtype_func]) raises -> NDArray[
diff --git a/numojo/routines/math/extrema.mojo b/numojo/routines/math/extrema.mojo
index 6995dd7f..e9b461ea 100644
--- a/numojo/routines/math/extrema.mojo
+++ b/numojo/routines/math/extrema.mojo
@@ -131,7 +131,7 @@ fn max[dtype: DType](a: NDArray[dtype], axis: Int) raises -> NDArray[dtype]:
         normalized_axis += a.ndim
     if (normalized_axis < 0) or (normalized_axis >= a.ndim):
         raise Error(
-            String("Error in `mean`: Axis {} not in bound [-{}, {})").format(
+            String("Error in `max`: Axis {} not in bound [-{}, {})").format(
                 axis, a.ndim, a.ndim
             )
         )
@@ -243,7 +243,7 @@ fn min[dtype: DType](a: NDArray[dtype], axis: Int) raises -> NDArray[dtype]:
         normalized_axis += a.ndim
     if (normalized_axis < 0) or (normalized_axis >= a.ndim):
         raise Error(
-            String("Error in `mean`: Axis {} not in bound [-{}, {})").format(
+            String("Error in `min`: Axis {} not in bound [-{}, {})").format(
                 axis, a.ndim, a.ndim
             )
         )
diff --git a/numojo/routines/searching.mojo b/numojo/routines/searching.mojo
index d57b006b..bbde908a 100644
--- a/numojo/routines/searching.mojo
+++ b/numojo/routines/searching.mojo
@@ -2,6 +2,7 @@
 # Searching
 # ===----------------------------------------------------------------------=== #
 
+import builtin.math as builtin_math
 import math
 from algorithm import vectorize
 from sys import simdwidthof
@@ -16,29 +17,159 @@ from numojo.routines.sorting import binary_sort
 from numojo.routines.math.extrema import _max, _min
 
 
-# * for loop version works fine for argmax and argmin, need to vectorize it
-fn argmax[dtype: DType](array: NDArray[dtype]) raises -> Int:
+fn argmax_1d[dtype: DType](a: NDArray[dtype]) raises -> Scalar[DType.index]:
+    """Returns the index of the maximum value in the buffer.
+    Regardless of the shape of input, it is treated as a 1-d array.
+
+    Parameters:
+        dtype: The element type.
+
+    Args:
+        a: An array.
+
+    Returns:
+        The index of the maximum value in the buffer.
+
+    Raises:
+        Error: If the array is empty.
+    """
+
+    # Guard first: `ptr[]` below would read from an empty buffer.
+    if a.size == 0:
+        raise Error("Error in `argmax_1d`: array is empty")
+
+    var ptr = a._buf.ptr
+    var value = ptr[]
+    var result: Int = 0
+
+    for i in range(a.size):
+        if ptr[] > value:
+            result = i
+            value = ptr[]
+        ptr += 1
+
+    return result
+
+
+fn argmin_1d[dtype: DType](a: NDArray[dtype]) raises -> Scalar[DType.index]:
+    """Returns the index of the minimum value in the buffer.
+    Regardless of the shape of input, it is treated as a 1-d array.
+
+    Parameters:
+        dtype: The element type.
+
+    Args:
+        a: An array.
+
+    Returns:
+        The index of the minimum value in the buffer.
+
+    Raises:
+        Error: If the array is empty.
     """
-    Argmax of a array.
+
+    # Guard first: `ptr[]` below would read from an empty buffer.
+    if a.size == 0:
+        raise Error("Error in `argmin_1d`: array is empty")
+
+    var ptr = a._buf.ptr
+    var value = ptr[]
+    var result: Int = 0
+
+    for i in range(a.size):
+        if ptr[] < value:
+            result = i
+            value = ptr[]
+        ptr += 1
+
+    return result
+
+
+fn argmax[dtype: DType, //](a: NDArray[dtype]) raises -> Scalar[DType.index]:
+    """Returns the index of the maximum value of the flattened array.
+    The array is raveled first unless it is already 1-d.
 
     Parameters:
         dtype: The element type.
 
     Args:
-        array: A array.
+        a: An array.
+
     Returns:
-        The index of the maximum value of the array.
+        The index of the maximum value in the flattened array.
+
+    Notes:
+
+    If there are multiple occurrences of the maximum values, the indices
+    of the first occurrence are returned.
     """
-    if array.size == 0:
-        raise Error("array is empty")
 
-    var idx: Int = 0
-    var max_val: Scalar[dtype] = array.load(0)
-    for i in range(1, array.size):
-        if array.load(i) > max_val:
-            max_val = array.load(i)
-            idx = i
-    return idx
+    if a.ndim == 1:
+        return argmax_1d(a)
+    else:
+        return argmax_1d(ravel(a))
+
+
+fn argmax[
+    dtype: DType, //
+](a: NDArray[dtype], axis: Int) raises -> NDArray[DType.index]:
+    """Returns the indices of the maximum values of the array along an axis.
+    The given axis is reduced; a negative axis counts from the last one.
+
+    Parameters:
+        dtype: The element type.
+
+    Args:
+        a: An array.
+        axis: The axis along which to operate.
+
+    Returns:
+        Returns the indices of the maximum values of the array along an axis.
+
+    Notes:
+
+    If there are multiple occurrences of the maximum values, the indices
+    of the first occurrence are returned.
+
+    Examples:
+
+    ```mojo
+    from numojo.prelude import *
+    from python import Python
+
+    fn main() raises:
+        var np = Python.import_module("numpy")
+        # Test with argmax to get maximum values
+        var a = nm.random.randint(5, 4, low=0, high=10)
+        var a_np = a.to_numpy()
+        print(a)
+        print(a_np)
+        # Get indices of maximum values along axis=1
+        var max_indices = nm.argmax(a, axis=1)
+        var max_indices_np = np.argmax(a_np, axis=1)
+        # Reshape indices for take_along_axis
+        var reshaped_indices = max_indices.reshape(Shape(max_indices.shape[0], 1))
+        var reshaped_indices_np = max_indices_np.reshape(max_indices_np.shape[0], 1)
+        print(reshaped_indices)
+        print(reshaped_indices_np)
+        # Get maximum values using take_along_axis
+        print(nm.indexing.take_along_axis(a, reshaped_indices, axis=1))
+        print(np.take_along_axis(a_np, reshaped_indices_np, axis=1))
+    ```
+    End of examples.
+    """
+
+    var normalized_axis = axis
+    if axis < 0:
+        normalized_axis += a.ndim
+    if (normalized_axis < 0) or (normalized_axis >= a.ndim):
+        raise Error(
+            String("Error in `argmax`: Axis {} not in bound [-{}, {})").format(
+                axis, a.ndim, a.ndim
+            )
+        )
+
+    return numojo.apply_along_axis[func1d=argmax_1d](a=a, axis=normalized_axis)
 
 
 fn argmax[dtype: DType](A: Matrix[dtype]) raises -> Scalar[DType.index]:
@@ -76,28 +207,64 @@ fn argmax[
         raise Error(String("The axis can either be 1 or 0!"))
 
 
-fn argmin[dtype: DType](array: NDArray[dtype]) raises -> Int:
+fn argmin[dtype: DType, //](a: NDArray[dtype]) raises -> Scalar[DType.index]:
+    """Returns the index of the minimum value of the flattened array.
+    The array is raveled first unless it is already 1-d.
+
+    Parameters:
+        dtype: The element type.
+
+    Args:
+        a: An array.
+
+    Returns:
+        The index of the minimum value in the flattened array.
+
+    Notes:
+
+    If there are multiple occurrences of the minimum values, the indices
+    of the first occurrence are returned.
     """
-    Argmin of a array.
+
+    if a.ndim == 1:
+        return argmin_1d(a)
+    else:
+        return argmin_1d(ravel(a))
+
+
+fn argmin[
+    dtype: DType, //
+](a: NDArray[dtype], axis: Int) raises -> NDArray[DType.index]:
+    """Returns the indices of the minimum values of the array along an axis.
+    The given axis is reduced; a negative axis counts from the last one.
+
     Parameters:
         dtype: The element type.
 
     Args:
-        array: A array.
+        a: An array.
+        axis: The axis along which to operate.
+
     Returns:
-        The index of the minimum value of the array.
+        Returns the indices of the minimum values of the array along an axis.
+
+    Notes:
+
+    If there are multiple occurrences of the minimum values, the indices
+    of the first occurrence are returned.
     """
-    if array.size == 0:
-        raise Error("array is empty")
 
-    var idx: Int = 0
-    var min_val: Scalar[dtype] = array.load(0)
+    var normalized_axis = axis
+    if axis < 0:
+        normalized_axis += a.ndim
+    if (normalized_axis < 0) or (normalized_axis >= a.ndim):
+        raise Error(
+            String("Error in `argmin`: Axis {} not in bound [-{}, {})").format(
+                axis, a.ndim, a.ndim
+            )
+        )
 
-    for i in range(1, array.size):
-        if array.load(i) < min_val:
-            min_val = array.load(i)
-            idx = i
-    return idx
+    return numojo.apply_along_axis[func1d=argmin_1d](a=a, axis=normalized_axis)
 
 
 fn argmin[dtype: DType](A: Matrix[dtype]) raises -> Scalar[DType.index]:
diff --git a/tests/routines/test_searching.mojo b/tests/routines/test_searching.mojo
new file mode 100644
index 00000000..88b7561c
--- /dev/null
+++ b/tests/routines/test_searching.mojo
@@ -0,0 +1,221 @@
+from numojo.prelude import *
+from python import Python, PythonObject
+from utils_for_test import check, check_is_close, check_values_close
+
+
+fn test_argmax() raises:
+    var np = Python.import_module("numpy")
+
+    # Test 1D array
+    var a1d = nm.array[nm.f32]("[3.4, 1.2, 5.7, 0.9, 2.3]")
+    var a1d_np = a1d.to_numpy()
+
+    check_values_close(
+        nm.argmax(a1d),
+        np.argmax(a1d_np),
+        "`argmax` with 1D array is broken",
+    )
+
+    # Test 2D array without specifying axis (flattened)
+    var a2d = nm.array[nm.f32](
+        "[[3.4, 1.2, 5.7], [0.9, 2.3, 4.1], [7.6, 0.5, 2.8]]"
+    )
+    var a2d_np = a2d.to_numpy()
+
+    check_values_close(
+        nm.argmax(a2d),
+        np.argmax(a2d_np),
+        "`argmax` with 2D array (flattened) is broken",
+    )
+
+    # Test 2D array with axis=0
+    check(
+        nm.argmax(a2d, axis=0),
+        np.argmax(a2d_np, axis=0),
+        "`argmax` with 2D array on axis=0 is broken",
+    )
+
+    # Test 2D array with axis=1
+    check(
+        nm.argmax(a2d, axis=1),
+        np.argmax(a2d_np, axis=1),
+        "`argmax` with 2D array on axis=1 is broken",
+    )
+
+    # Test 2D array with negative axis
+    check(
+        nm.argmax(a2d, axis=-1),
+        np.argmax(a2d_np, axis=-1),
+        "`argmax` with 2D array on negative axis is broken",
+    )
+
+    # Test 3D array
+    var a3d = nm.random.randint(2, 3, 4, low=0, high=10)
+    var a3d_np = a3d.to_numpy()
+
+    check_values_close(
+        nm.argmax(a3d),
+        np.argmax(a3d_np),
+        "`argmax` with 3D array (flattened) is broken",
+    )
+
+    check(
+        nm.argmax(a3d, axis=0),
+        np.argmax(a3d_np, axis=0),
+        "`argmax` with 3D array on axis=0 is broken",
+    )
+
+    check(
+        nm.argmax(a3d, axis=1),
+        np.argmax(a3d_np, axis=1),
+        "`argmax` with 3D array on axis=1 is broken",
+    )
+
+    check(
+        nm.argmax(a3d, axis=2),
+        np.argmax(a3d_np, axis=2),
+        "`argmax` with 3D array on axis=2 is broken",
+    )
+
+    # Test with F-order array
+    var a3d_f = nm.random.randint(2, 3, 4, low=0, high=10).reshape(
+        Shape(2, 3, 4), order="F"
+    )
+    var a3d_f_np = a3d_f.to_numpy()
+
+    for i in range(3):
+        check(
+            nm.argmax(a3d_f, axis=i),
+            np.argmax(a3d_f_np, axis=i),
+            "`argmax` with F-order 3D array on axis={} is broken".format(i),
+        )
+
+
+fn test_argmin() raises:
+    var np = Python.import_module("numpy")
+
+    # Test 1D array
+    var a1d = nm.array[nm.f32]("[3.4, 1.2, 5.7, 0.9, 2.3]")
+    var a1d_np = a1d.to_numpy()
+
+    check_values_close(
+        nm.argmin(a1d),
+        np.argmin(a1d_np),
+        "`argmin` with 1D array is broken",
+    )
+
+    # Test 2D array without specifying axis (flattened)
+    var a2d = nm.array[nm.f32](
+        "[[3.4, 1.2, 5.7], [0.9, 2.3, 4.1], [7.6, 0.5, 2.8]]"
+    )
+    var a2d_np = a2d.to_numpy()
+
+    check_values_close(
+        nm.argmin(a2d),
+        np.argmin(a2d_np),
+        "`argmin` with 2D array (flattened) is broken",
+    )
+
+    # Test 2D array with axis=0
+    check(
+        nm.argmin(a2d, axis=0),
+        np.argmin(a2d_np, axis=0),
+        "`argmin` with 2D array on axis=0 is broken",
+    )
+
+    # Test 2D array with axis=1
+    check(
+        nm.argmin(a2d, axis=1),
+        np.argmin(a2d_np, axis=1),
+        "`argmin` with 2D array on axis=1 is broken",
+    )
+
+    # Test 2D array with negative axis
+    check(
+        nm.argmin(a2d, axis=-1),
+        np.argmin(a2d_np, axis=-1),
+        "`argmin` with 2D array on negative axis is broken",
+    )
+
+    # Test 3D array
+    var a3d = nm.random.randint(2, 3, 4, low=0, high=10)
+    var a3d_np = a3d.to_numpy()
+
+    check_values_close(
+        nm.argmin(a3d),
+        np.argmin(a3d_np),
+        "`argmin` with 3D array (flattened) is broken",
+    )
+
+    check(
+        nm.argmin(a3d, axis=0),
+        np.argmin(a3d_np, axis=0),
+        "`argmin` with 3D array on axis=0 is broken",
+    )
+
+    check(
+        nm.argmin(a3d, axis=1),
+        np.argmin(a3d_np, axis=1),
+        "`argmin` with 3D array on axis=1 is broken",
+    )
+
+    check(
+        nm.argmin(a3d, axis=2),
+        np.argmin(a3d_np, axis=2),
+        "`argmin` with 3D array on axis=2 is broken",
+    )
+
+    # Test with F-order array
+    var a3d_f = nm.random.randint(2, 3, 4, low=0, high=10).reshape(
+        Shape(2, 3, 4), order="F"
+    )
+    var a3d_f_np = a3d_f.to_numpy()
+
+    for i in range(3):
+        check(
+            nm.argmin(a3d_f, axis=i),
+            np.argmin(a3d_f_np, axis=i),
+            "`argmin` with F-order 3D array on axis={} is broken".format(i),
+        )
+
+
+fn test_take_along_axis_with_argmax_argmin() raises:
+    var np = Python.import_module("numpy")
+
+    # Test with argmax to get maximum values
+    var a2d = nm.random.randint(5, 4, low=0, high=10)
+    var a2d_np = a2d.to_numpy()
+
+    # Get indices of maximum values along axis=1
+    var max_indices = nm.argmax(a2d, axis=1)
+    var max_indices_np = np.argmax(a2d_np, axis=1)
+
+    # Reshape indices for take_along_axis
+    var reshaped_indices = max_indices.reshape(Shape(max_indices.shape[0], 1))
+    var reshaped_indices_np = max_indices_np.reshape(max_indices_np.shape[0], 1)
+
+    # Get maximum values using take_along_axis
+    check(
+        nm.indexing.take_along_axis(a2d, reshaped_indices, axis=1),
+        np.take_along_axis(a2d_np, reshaped_indices_np, axis=1),
+        "`take_along_axis` with argmax is broken",
+    )
+
+    # Test with argmin to get minimum values
+    var min_indices = nm.argmin(a2d, axis=1)
+    var min_indices_np = np.argmin(a2d_np, axis=1)
+
+    # Reshape indices for take_along_axis
+    var reshaped_min_indices = min_indices.reshape(
+        Shape(min_indices.shape[0], 1)
+    )
+    var reshaped_min_indices_np = min_indices_np.reshape(
+        min_indices_np.shape[0], 1
+    )
+
+    # Get minimum values using take_along_axis
+    check(
+        nm.indexing.take_along_axis(a2d, reshaped_min_indices, axis=1),
+        np.take_along_axis(a2d_np, reshaped_min_indices_np, axis=1),
+        "`take_along_axis` with argmin is broken",
+    )
diff --git a/tests/utils_for_test.mojo b/tests/utils_for_test.mojo
index f1a211d6..ed012ec4 100644
--- a/tests/utils_for_test.mojo
+++ b/tests/utils_for_test.mojo
@@ -4,7 +4,7 @@ import numojo as nm
 
 
 fn check[
-    dtype: DType
+    dtype: DType, //
 ](array: nm.NDArray[dtype], np_sol: PythonObject, st: String) raises:
     var np = Python.import_module("numpy")
     assert_true(np.all(np.equal(array.to_numpy(), np_sol)), st)