diff --git a/.gitignore b/.gitignore index 6b524117..496e0a39 100644 --- a/.gitignore +++ b/.gitignore @@ -27,7 +27,7 @@ magic.lock # Miscellaneous files mojo -/numojo.mojopkg +numojo.mojopkg /bench.mojo /test*.mojo /test*.ipynb diff --git a/mojoproject.toml b/mojoproject.toml index 3e852a85..ce4f6ff9 100644 --- a/mojoproject.toml +++ b/mojoproject.toml @@ -16,15 +16,15 @@ license = "Apache-2.0" readme = "README.MD" [tasks] -# compile the package -package = "magic run mojo package numojo" +# compile the package and copy it to the tests folder +package = "magic run mojo package numojo && cp numojo.mojopkg tests/" p = "clear && magic run package" # format the package -format = "magic run mojo format ./ && magic run mojo format docs/readthedocs/docs.py" +format = "magic run mojo format ./" -# test whether tests pass and the package can be built -test = "magic run mojo test tests -I ./ -I ./tests/ && magic run package" +# test whether tests pass on the built package +test = "magic run package && magic run mojo test tests -I ./tests/" t = "clear && magic run test" # run individual tests to avoid overheat @@ -40,8 +40,8 @@ test_statistics = "magic run mojo test tests/routines/test_statistics.mojo -I ./ test_sorting = "magic run mojo test tests/routines/test_sorting.mojo -I ./ -I ./tests/" # run all final checks before a commit -final = "magic run test && magic run format && magic run package" -f = "clear && magic run test && magic run format && magic run package" +final = "magic run format && magic run test" +f = "clear && magic run final" # Automatically Generate doc pages doc_pages = "mojo doc numojo/ -o docs.json" @@ -50,7 +50,7 @@ doc_pages = "mojo doc numojo/ -o docs.json" release = "clear && magic run final && magic run doc_pages" [dependencies] -max = "=25.1" +max = "=25.1.1" python = ">=3.11" numpy = ">=1.19" scipy = ">=1.14" \ No newline at end of file diff --git a/numojo/__init__.mojo b/numojo/__init__.mojo index e1d44323..87e82680 100644 --- a/numojo/__init__.mojo +++ b/numojo/__init__.mojo @@ -62,7 +62,7 @@ from numojo.routines.io import ( loadtxt, savetxt, ) -from numojo.routines.io import printoptions, set_printoptions +from numojo.routines.io import set_printoptions from numojo.routines import linalg from numojo.routines.linalg.misc import diagonal @@ -181,7 +181,7 @@ from numojo.routines.creation import ( ) from numojo.routines import indexing -from numojo.routines.indexing import where, compress +from numojo.routines.indexing import where, compress, take_along_axis from numojo.routines.functional import apply_along_axis diff --git a/numojo/core/complex/complex_ndarray.mojo b/numojo/core/complex/complex_ndarray.mojo index d836720d..2e7583ae 100644 --- a/numojo/core/complex/complex_ndarray.mojo +++ b/numojo/core/complex/complex_ndarray.mojo @@ -34,7 +34,6 @@ from numojo.routines.io.formatting import ( format_floating_scientific, format_value, PrintOptions, - printoptions, GLOBAL_PRINT_OPTIONS, ) import numojo.routines.linalg as linalg diff --git a/numojo/routines/indexing.mojo b/numojo/routines/indexing.mojo index 421eb825..79e2d0c4 100644 --- a/numojo/routines/indexing.mojo +++ b/numojo/routines/indexing.mojo @@ -232,3 +232,126 @@ fn compress[ else: return compress(condition, ravel(a), axis=0) + + +fn take_along_axis[ + dtype: DType, //, +]( + arr: NDArray[dtype], indices: NDArray[DType.index], axis: Int = 0 +) raises -> NDArray[dtype]: + """ + Takes values from the input array along the given axis based on indices. + + Raises: + Error: If the axis is out of bounds for the given array. + Error: If the ndim of arr and indices are not the same. + Error: If the shape of indices does not match the shape of the + input array except along the given axis. + + Parameters: + dtype: DType of the input array. + + Args: + arr: The source array. + indices: The indices array. + axis: The axis along which to take values. Default is 0. + + Returns: + An array with the same shape as indices with values taken from the + input array along the given axis. + + Examples: + + ```console + > var a = nm.arange[i8](12).reshape(Shape(3, 4)) + > print(a) + [[ 0 1 2 3] + [ 4 5 6 7] + [ 8 9 10 11]] + > ind = nm.array[intp]("[[0, 1, 2, 0], [1, 0, 2, 1]]") + > print(ind) + [[0 1 2 0] + [1 0 2 1]] + > print(nm.indexing.take_along_axis(a, ind, axis=0)) + [[ 0 5 10 3] + [ 4 1 10 7]] + ``` + . + """ + var normalized_axis = axis + if normalized_axis < 0: + normalized_axis = arr.ndim + normalized_axis + if (normalized_axis >= arr.ndim) or (normalized_axis < 0): + raise Error( + String( + "\nError in `take_along_axis`: Axis {} is out of bound for" + " array with {} dimensions" + ).format(axis, arr.ndim) + ) + + # Check if the ndim of arr and indices are same + if arr.ndim != indices.ndim: + raise Error( + String( + "\nError in `take_along_axis`: The ndim of arr and indices must" + " be same. Got {} and {}.".format(arr.ndim, indices.ndim) + ) + ) + + # broadcast indices to the shape of arr if necessary + # When broadcasting, the shape of indices must match the shape of arr + # except along the axis + + var broadcasted_indices = indices + + if arr.shape != indices.shape: + var arr_shape_new = arr.shape + arr_shape_new[normalized_axis] = indices.shape[normalized_axis] + + try: + broadcasted_indices = numojo.broadcast_to(indices, arr_shape_new) + except e: + raise Error( + String( + "\nError in `take_along_axis`: Shape of indices must match" + " shape of array except along the given axis. " + + String(e) + ) + ) + + # Create output array with same shape as broadcasted_indices + var result = NDArray[dtype](Shape(broadcasted_indices.shape)) + + var arr_iterator = arr.iter_along_axis(normalized_axis) + var indices_iterator = broadcasted_indices.iter_along_axis(normalized_axis) + var length_of_iterator = result.size // result.shape[normalized_axis] + + if normalized_axis == arr.ndim - 1: + # If axis is the last axis, the data is contiguous. + for i in range(length_of_iterator): + var arr_slice = arr_iterator.ith(i) + var indices_slice = indices_iterator.ith(i) + var arr_slice_after_applying_indices = arr_slice[indices_slice] + memcpy( + result._buf.ptr + i * result.shape[normalized_axis], + arr_slice_after_applying_indices._buf.ptr, + result.shape[normalized_axis], + ) + else: + # If axis is not the last axis, the data is not contiguous. + for i in range(length_of_iterator): + var indices_slice_offsets: NDArray[DType.index] + var indices_slice: NDArray[DType.index] + indices_slice_offsets, indices_slice = ( + indices_iterator.ith_with_offsets(i) + ) + var arr_slice = arr_iterator.ith(i) + var arr_slice_after_applying_indices = arr_slice[indices_slice] + for j in range(arr_slice_after_applying_indices.size): + ( + result._buf.ptr + Int(indices_slice_offsets[j]) + ).init_pointee_copy( + arr_slice_after_applying_indices._buf.ptr[j] + ) + + return result diff --git a/numojo/routines/io/__init__.mojo b/numojo/routines/io/__init__.mojo index 6f95cf23..4aff90cd 100644 --- a/numojo/routines/io/__init__.mojo +++ b/numojo/routines/io/__init__.mojo @@ -6,7 +6,5 @@ from .files import ( from .formatting import ( format_floating_scientific, PrintOptions, - printoptions, - GLOBAL_PRINT_OPTIONS, set_printoptions, ) diff --git a/numojo/routines/io/formatting.mojo b/numojo/routines/io/formatting.mojo index ed983a9c..32295a83 100644 --- a/numojo/routines/io/formatting.mojo +++ b/numojo/routines/io/formatting.mojo @@ -20,25 +20,7 @@ alias DEFAULT_FORMATTED_WIDTH = 8 alias DEFAULT_EXPONENT_THRESHOLD = 4 alias DEFAULT_SUPPRESS_SCIENTIFIC = False -var GLOBAL_PRINT_OPTIONS: PrintOptions = PrintOptions( - precision=DEFAULT_PRECISION, - suppress_small=DEFAULT_SUPPRESS_SMALL, - separator=DEFAULT_SEPARATOR, - padding=DEFAULT_PADDING, - threshold=DEFAULT_THRESHOLD, - line_width=DEFAULT_LINE_WIDTH, - edge_items=DEFAULT_EDGE_ITEMS, - sign=DEFAULT_SIGN, - float_format=DEFAULT_FLOAT_FORMAT, - complex_format=DEFAULT_COMPLEX_FORMAT, - nan_string=DEFAULT_NAN_STRING, - inf_string=DEFAULT_INF_STRING, - formatted_width=DEFAULT_FORMATTED_WIDTH, - exponent_threshold=DEFAULT_EXPONENT_THRESHOLD, - suppress_scientific=DEFAULT_SUPPRESS_SCIENTIFIC, -) - -alias printoptions = PrintOptions +alias GLOBAL_PRINT_OPTIONS = PrintOptions() @value @@ -146,7 +128,8 @@ struct PrintOptions: self.suppress_scientific = suppress_scientific fn __enter__(mut self) -> Self: - GLOBAL_PRINT_OPTIONS.set_options( + var default_print_options = PrintOptions() + default_print_options.set_options( precision=self.precision, suppress_small=self.suppress_small, separator=self.separator, @@ -163,10 +146,11 @@ struct PrintOptions: exponent_threshold=self.exponent_threshold, suppress_scientific=self.suppress_scientific, ) - return GLOBAL_PRINT_OPTIONS + return default_print_options fn __exit__(mut self): - GLOBAL_PRINT_OPTIONS.set_options( + var default_print_options = PrintOptions() + default_print_options.set_options( precision=DEFAULT_PRECISION, suppress_small=DEFAULT_SUPPRESS_SMALL, separator=DEFAULT_SEPARATOR, @@ -192,7 +176,8 @@ fn set_printoptions( padding: String = DEFAULT_PADDING, edge_items: Int = DEFAULT_EDGE_ITEMS, ): - GLOBAL_PRINT_OPTIONS.set_options( + var default_print_options = PrintOptions() + default_print_options.set_options( precision, suppress_small, separator, @@ -204,7 +189,14 @@ fn set_printoptions( # FIXME: fix the problem where precision > number of digits in the mantissa results in a not so exact value. fn format_floating_scientific[ dtype: DType = DType.float64 -](x: Scalar[dtype], precision: Int = 10, sign: Bool = False) raises -> String: +]( + x: Scalar[dtype], + precision: Int = 10, + sign: Bool = False, + suppress_scientific: Bool = False, + exponent_threshold: Int = 4, + formatted_width: Int = 8, +) raises -> String: """ Format a float in scientific notation. @@ -218,6 +210,11 @@ fn format_floating_scientific[ x: The float to format. precision: The number of decimal places to include in the mantissa. sign: Whether to include the sign of the float in the result. Defaults to False. + suppress_scientific: Whether to suppress scientific notation for small numbers. + Defaults to False. + exponent_threshold: The threshold for suppressing scientific notation. + Defaults to 4. + formatted_width: The width of the formatted string. Defaults to 8. Returns: A string representation of the float in scientific notation. @@ -235,10 +232,6 @@ fn format_floating_scientific[ raise Error("Precision must be a non-negative integer.") try: - var suppress_scientific = GLOBAL_PRINT_OPTIONS.suppress_scientific - var exponent_threshold = GLOBAL_PRINT_OPTIONS.exponent_threshold - var formatted_width = GLOBAL_PRINT_OPTIONS.formatted_width - if x == 0: if sign: var result: String = "+0." + "0" * precision + "e+00" @@ -290,7 +283,12 @@ fn format_floating_scientific[ fn format_floating_precision[ dtype: DType -](value: Scalar[dtype], precision: Int, sign: Bool = False) raises -> String: +]( + value: Scalar[dtype], + precision: Int, + sign: Bool = False, + suppress_small: Bool = False, +) raises -> String: """ Format a floating-point value to the specified precision. @@ -299,6 +297,7 @@ fn format_floating_precision[ precision: The number of decimal places to include. sign: Whether to include the sign of the float in the result. Defaults to False. + suppress_small: Whether to suppress small numbers. Defaults to False. Returns: The formatted value as a string. @@ -316,7 +315,6 @@ fn format_floating_precision[ if precision < 0: raise Error("Precision must be a non-negative integer.") - var suppress_small = GLOBAL_PRINT_OPTIONS.suppress_small if suppress_small and abs(value) < 1e-10: var result: String = String("0.") for _ in range(precision): @@ -349,12 +347,19 @@ fn format_floating_precision[ fn format_floating_precision[ cdtype: CDType, dtype: DType -](value: ComplexSIMD[cdtype, dtype=dtype]) raises -> String: +]( + value: ComplexSIMD[cdtype, dtype=dtype], + precision: Int = 4, + sign: Bool = False, +) raises -> String: """ Format a complex floating-point value to the specified precision. Args: value: The complex value to format. + precision: The number of decimal places to include. + sign: Whether to include the sign of the float in the result. + Defaults to False. Returns: The formatted value as a string. @@ -363,9 +368,6 @@ fn format_floating_precision[ Error: If the complex value cannot be formatted. """ try: - var precision = GLOBAL_PRINT_OPTIONS.precision - var sign = GLOBAL_PRINT_OPTIONS.sign - return ( "(" + format_floating_precision(value.re, precision, sign) diff --git a/tests/routines/test_indexing.mojo b/tests/routines/test_indexing.mojo index a443e9f8..d0d5e8b9 100644 --- a/tests/routines/test_indexing.mojo +++ b/tests/routines/test_indexing.mojo @@ -63,3 +63,236 @@ fn test_compress() raises: np.compress(np.array([0, 1]), anp, axis=2), "`compress` 3-d array by axis 2 is broken", ) + + +fn test_take_along_axis() raises: + var np = Python.import_module("numpy") + + # Test 1-D array + var a1d = nm.arange[i8](10) + var a1d_np = a1d.to_numpy() + var indices1d = nm.array[intp]("[2, 3, 1, 8]") + var indices1d_np = indices1d.to_numpy() + + check( + nm.indexing.take_along_axis(a1d, indices1d, axis=0), + np.take_along_axis(a1d_np, indices1d_np, axis=0), + "`take_along_axis` with 1-D array is broken", + ) + + # Test 2-D array with axis=0 + var a2d = nm.arange[i8](12).reshape(Shape(3, 4)) + var a2d_np = a2d.to_numpy() + var indices2d_0 = nm.array[intp]("[[0, 1, 2, 0], [1, 2, 0, 1]]") + var indices2d_0_np = indices2d_0.to_numpy() + + check( + nm.indexing.take_along_axis(a2d, indices2d_0, axis=0), + np.take_along_axis(a2d_np, indices2d_0_np, axis=0), + "`take_along_axis` with 2-D array on axis=0 is broken", + ) + + # Test 2-D array with axis=1 + var indices2d_1 = nm.array[intp]( + "[[3, 0, 2, 1], [1, 3, 0, 0], [2, 1, 0, 3]]" + ) + var indices2d_1_np = indices2d_1.to_numpy() + + check( + nm.indexing.take_along_axis(a2d, indices2d_1, axis=1), + np.take_along_axis(a2d_np, indices2d_1_np, axis=1), + "`take_along_axis` with 2-D array on axis=1 is broken", + ) + + # Test 3-D array + var a3d = nm.arange[i8](24).reshape(Shape(2, 3, 4)) + var a3d_np = a3d.to_numpy() + + # Test with axis=0 + var indices3d_0 = nm.zeros[intp](Shape(1, 3, 4)) + var indices3d_0_np = indices3d_0.to_numpy() + + check( + nm.indexing.take_along_axis(a3d, indices3d_0, axis=0), + np.take_along_axis(a3d_np, indices3d_0_np, axis=0), + "`take_along_axis` with 3-D array on axis=0 is broken", + ) + + # Test with axis=1 + var indices3d_1 = nm.array[intp]( + "[[[0, 1, 0, 2], [2, 1, 0, 1], [1, 2, 2, 0]], [[1, 0, 1, 2], [0, 2, 1," + " 0], [2, 0, 0, 1]]]" + ) + var indices3d_1_np = indices3d_1.to_numpy() + + check( + nm.indexing.take_along_axis(a3d, indices3d_1, axis=1), + np.take_along_axis(a3d_np, indices3d_1_np, axis=1), + "`take_along_axis` with 3-D array on axis=1 is broken", + ) + + # Test with axis=2 + var indices3d_2 = nm.array[intp]( + "[[[2, 0, 3, 1], [1, 3, 0, 2], [3, 1, 2, 0]], [[0, 2, 1, 3], [2, 0, 3," + " 1], [1, 3, 0, 2]]]" + ) + var indices3d_2_np = indices3d_2.to_numpy() + + check( + nm.indexing.take_along_axis(a3d, indices3d_2, axis=2), + np.take_along_axis(a3d_np, indices3d_2_np, axis=2), + "`take_along_axis` with 3-D array on axis=2 is broken", + ) + + # Test with negative axis + check( + nm.indexing.take_along_axis(a3d, indices3d_2, axis=-1), + np.take_along_axis(a3d_np, indices3d_2_np, axis=-1), + "`take_along_axis` with negative axis is broken", + ) + + # Test cases where indices shape matches array shape except on the target axis + + # For 2D array (3, 4) + var a2d_test = nm.arange[i8](12).reshape(Shape(3, 4)) + var a2d_test_np = a2d_test.to_numpy() + + # For axis=0, using indices of shape (2, 4) - different first dim, same second dim + var indices2d_axis0 = nm.array[intp]("[[0, 1, 2, 0], [1, 0, 2, 1]]") + var indices2d_axis0_np = indices2d_axis0.to_numpy() + + check( + nm.indexing.take_along_axis(a2d_test, indices2d_axis0, axis=0), + np.take_along_axis(a2d_test_np, indices2d_axis0_np, axis=0), + ( + "`take_along_axis` with shape-aligned indices on axis=0 for 2D" + " array is broken" + ), + ) + + # For axis=1, using indices of shape (3, 2) - same first dim, different second dim + var indices2d_axis1 = nm.array[intp]("[[0, 3], [2, 1], [1, 3]]") + var indices2d_axis1_np = indices2d_axis1.to_numpy() + + check( + nm.indexing.take_along_axis(a2d_test, indices2d_axis1, axis=1), + np.take_along_axis(a2d_test_np, indices2d_axis1_np, axis=1), + ( + "`take_along_axis` with shape-aligned indices on axis=1 for 2D" + " array is broken" + ), + ) + + # For 3D array (2, 3, 4) + # Reshape and get base numpy array + var a3d_test = nm.arange[i8](24).reshape(Shape(2, 3, 4)) + var a3d_test_np = a3d_test.to_numpy() + + # For axis=0, indices of shape (1, 3, 4) - same shape except dim 0 + var ind_axis0 = nm.zeros[intp](Shape(1, 3, 4)) + var ind_axis0_np = ind_axis0.to_numpy() + + check( + nm.indexing.take_along_axis(a3d_test, ind_axis0, axis=0), + np.take_along_axis(a3d_test_np, ind_axis0_np, axis=0), + ( + "`take_along_axis` with shape-aligned indices on axis=0 for 3D" + " array is broken" + ), + ) + + # For axis=2, indices of shape (2, 3, 2) - same shape except dim 2 + var ind_axis2 = nm.array[intp]( + "[[[0, 3], [2, 1], [3, 0]], [[1, 2], [0, 3], [2, 1]]]" + ) + var ind_axis2_np = ind_axis2.to_numpy() + + check( + nm.indexing.take_along_axis(a3d_test, ind_axis2, axis=2), + np.take_along_axis(a3d_test_np, ind_axis2_np, axis=2), + ( + "`take_along_axis` with shape-aligned indices on axis=2 for 3D" + " array is broken" + ), + ) + + +fn test_take_along_axis_fortran_order() raises: + var np = Python.import_module("numpy") + + # Create 3-D F-order array + var a3d_f = nm.arange[i8](24).reshape(Shape(2, 3, 4), order="F") + var a3d_f_np = a3d_f.to_numpy() + + # Test with axis=0 + var indices3d_0 = nm.zeros[intp](Shape(1, 3, 4)) + var indices3d_0_np = indices3d_0.to_numpy() + + check( + nm.indexing.take_along_axis(a3d_f, indices3d_0, axis=0), + np.take_along_axis(a3d_f_np, indices3d_0_np, axis=0), + "`take_along_axis` with 3-D F-order array on axis=0 is broken", + ) + + # Test with axis=1 + var indices3d_1 = nm.array[intp]( + "[[[0, 1, 0, 2], [2, 1, 0, 1], [1, 2, 2, 0]], [[1, 0, 1, 2], [0, 2, 1," + " 0], [2, 0, 0, 1]]]" + ) + var indices3d_1_np = indices3d_1.to_numpy() + + check( + nm.indexing.take_along_axis(a3d_f, indices3d_1, axis=1), + np.take_along_axis(a3d_f_np, indices3d_1_np, axis=1), + "`take_along_axis` with 3-D F-order array on axis=1 is broken", + ) + + # Test with axis=2 + var indices3d_2 = nm.array[intp]( + "[[[2, 0, 3, 1], [1, 3, 0, 2], [3, 1, 2, 0]], [[0, 2, 1, 3], [2, 0, 3," + " 1], [1, 3, 0, 2]]]" + ) + var indices3d_2_np = indices3d_2.to_numpy() + + check( + nm.indexing.take_along_axis(a3d_f, indices3d_2, axis=2), + np.take_along_axis(a3d_f_np, indices3d_2_np, axis=2), + "`take_along_axis` with 3-D F-order array on axis=2 is broken", + ) + + # Test with argsort use case on each axis + var sorted_indices_0 = nm.argsort(a3d_f, axis=0) + var sorted_indices_0_np = sorted_indices_0.to_numpy() + + check( + nm.indexing.take_along_axis(a3d_f, sorted_indices_0, axis=0), + np.take_along_axis(a3d_f_np, sorted_indices_0_np, axis=0), + ( + "`take_along_axis` with argsorted indices on axis=0 for F-order" + " array is broken" + ), + ) + + var sorted_indices_1 = nm.argsort(a3d_f, axis=1) + var sorted_indices_1_np = sorted_indices_1.to_numpy() + + check( + nm.indexing.take_along_axis(a3d_f, sorted_indices_1, axis=1), + np.take_along_axis(a3d_f_np, sorted_indices_1_np, axis=1), + ( + "`take_along_axis` with argsorted indices on axis=1 for F-order" + " array is broken" + ), + ) + + var sorted_indices_2 = nm.argsort(a3d_f, axis=2) + var sorted_indices_2_np = sorted_indices_2.to_numpy() + + check( + nm.indexing.take_along_axis(a3d_f, sorted_indices_2, axis=2), + np.take_along_axis(a3d_f_np, sorted_indices_2_np, axis=2), + ( + "`take_along_axis` with argsorted indices on axis=2 for F-order" + " array is broken" + ), + )