Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ magic.lock

# Miscellaneous files
mojo
/numojo.mojopkg
numojo.mojopkg
/bench.mojo
/test*.mojo
/test*.ipynb
Expand Down
16 changes: 8 additions & 8 deletions mojoproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,15 @@ license = "Apache-2.0"
readme = "README.MD"

[tasks]
# compile the package
package = "magic run mojo package numojo"
# compile the package and copy it to the tests folder
package = "magic run mojo package numojo && cp numojo.mojopkg tests/"
p = "clear && magic run package"

# format the package
format = "magic run mojo format ./ && magic run mojo format docs/readthedocs/docs.py"
format = "magic run mojo format ./"

# test whether tests pass and the package can be built
test = "magic run mojo test tests -I ./ -I ./tests/ && magic run package"
# test whether tests pass on the built package
test = "magic run package && magic run mojo test tests -I ./tests/"
t = "clear && magic run test"

# run individual tests to avoid overheat
Expand All @@ -40,8 +40,8 @@ test_statistics = "magic run mojo test tests/routines/test_statistics.mojo -I ./
test_sorting = "magic run mojo test tests/routines/test_sorting.mojo -I ./ -I ./tests/"

# run all final checks before a commit
final = "magic run test && magic run format && magic run package"
f = "clear && magic run test && magic run format && magic run package"
final = "magic run format && magic run test"
f = "clear && magic run final"

# Automatically Generate doc pages
doc_pages = "mojo doc numojo/ -o docs.json"
Expand All @@ -50,7 +50,7 @@ doc_pages = "mojo doc numojo/ -o docs.json"
release = "clear && magic run final && magic run doc_pages"

[dependencies]
max = "=25.1"
max = "=25.1.1"
python = ">=3.11"
numpy = ">=1.19"
scipy = ">=1.14"
4 changes: 2 additions & 2 deletions numojo/__init__.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ from numojo.routines.io import (
loadtxt,
savetxt,
)
from numojo.routines.io import printoptions, set_printoptions
from numojo.routines.io import set_printoptions

from numojo.routines import linalg
from numojo.routines.linalg.misc import diagonal
Expand Down Expand Up @@ -181,7 +181,7 @@ from numojo.routines.creation import (
)

from numojo.routines import indexing
from numojo.routines.indexing import where, compress
from numojo.routines.indexing import where, compress, take_along_axis

from numojo.routines.functional import apply_along_axis

Expand Down
1 change: 0 additions & 1 deletion numojo/core/complex/complex_ndarray.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ from numojo.routines.io.formatting import (
format_floating_scientific,
format_value,
PrintOptions,
printoptions,
GLOBAL_PRINT_OPTIONS,
)
import numojo.routines.linalg as linalg
Expand Down
123 changes: 123 additions & 0 deletions numojo/routines/indexing.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -232,3 +232,126 @@ fn compress[

else:
return compress(condition, ravel(a), axis=0)


fn take_along_axis[
dtype: DType, //,
](
arr: NDArray[dtype], indices: NDArray[DType.index], axis: Int = 0
) raises -> NDArray[dtype]:
"""
Takes values from the input array along the given axis based on indices.

Raises:
Error: If the axis is out of bounds for the given array.
Error: If the ndim of arr and indices are not the same.
Error: If the shape of indices does not match the shape of the
input array except along the given axis.

Parameters:
dtype: DType of the input array.

Args:
arr: The source array.
indices: The indices array.
axis: The axis along which to take values. Default is 0.

Returns:
An array with the same shape as indices with values taken from the
input array along the given axis.

Examples:

```console
> var a = nm.arange[i8](12).reshape(Shape(3, 4))
> print(a)
[[ 0 1 2 3]
[ 4 5 6 7]
[ 8 9 10 11]]
> ind = nm.array[intp]("[[0, 1, 2, 0], [1, 0, 2, 1]]")
> print(ind)
[[0 1 2 0]
[1 0 2 1]]
> print(nm.indexing.take_along_axis(a, ind, axis=0))
[[ 0 5 10 3]
[ 4 1 10 7]]
```
.
"""
var normalized_axis = axis
if normalized_axis < 0:
normalized_axis = arr.ndim + normalized_axis
if (normalized_axis >= arr.ndim) or (normalized_axis < 0):
raise Error(
String(
"\nError in `take_along_axis`: Axis {} is out of bound for"
" array with {} dimensions"
).format(axis, arr.ndim)
)

# Check if the ndim of arr and indices are same
if arr.ndim != indices.ndim:
raise Error(
String(
"\nError in `take_along_axis`: The ndim of arr and indices must"
" be same. Got {} and {}.".format(arr.ndim, indices.ndim)
)
)

# broadcast indices to the shape of arr if necessary
# When broadcasting, the shape of indices must match the shape of arr
# except along the axis

var broadcasted_indices = indices

if arr.shape != indices.shape:
var arr_shape_new = arr.shape
arr_shape_new[normalized_axis] = indices.shape[normalized_axis]

try:
broadcasted_indices = numojo.broadcast_to(indices, arr_shape_new)
except e:
raise Error(
String(
"\nError in `take_along_axis`: Shape of indices must match"
" shape of array except along the given axis. "
+ String(e)
)
)

# Create output array with same shape as broadcasted_indices
var result = NDArray[dtype](Shape(broadcasted_indices.shape))

var arr_iterator = arr.iter_along_axis(normalized_axis)
var indices_iterator = broadcasted_indices.iter_along_axis(normalized_axis)
var length_of_iterator = result.size // result.shape[normalized_axis]

if normalized_axis == arr.ndim - 1:
# If axis is the last axis, the data is contiguous.
for i in range(length_of_iterator):
var arr_slice = arr_iterator.ith(i)
var indices_slice = indices_iterator.ith(i)
var arr_slice_after_applying_indices = arr_slice[indices_slice]
memcpy(
result._buf.ptr + i * result.shape[normalized_axis],
arr_slice_after_applying_indices._buf.ptr,
result.shape[normalized_axis],
)
else:
# If axis is not the last axis, the data is not contiguous.
for i in range(length_of_iterator):
var indices_slice_offsets: NDArray[DType.index]
var indices_slice: NDArray[DType.index]
indices_slice_offsets, indices_slice = (
indices_iterator.ith_with_offsets(i)
)
var arr_slice = arr_iterator.ith(i)
var arr_slice_after_applying_indices = arr_slice[indices_slice]
for j in range(arr_slice_after_applying_indices.size):
(
result._buf.ptr + Int(indices_slice_offsets[j])
).init_pointee_copy(
arr_slice_after_applying_indices._buf.ptr[j]
)

return result
2 changes: 0 additions & 2 deletions numojo/routines/io/__init__.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,5 @@ from .files import (
from .formatting import (
format_floating_scientific,
PrintOptions,
printoptions,
GLOBAL_PRINT_OPTIONS,
set_printoptions,
)
70 changes: 36 additions & 34 deletions numojo/routines/io/formatting.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -20,25 +20,7 @@ alias DEFAULT_FORMATTED_WIDTH = 8
alias DEFAULT_EXPONENT_THRESHOLD = 4
alias DEFAULT_SUPPRESS_SCIENTIFIC = False

var GLOBAL_PRINT_OPTIONS: PrintOptions = PrintOptions(
precision=DEFAULT_PRECISION,
suppress_small=DEFAULT_SUPPRESS_SMALL,
separator=DEFAULT_SEPARATOR,
padding=DEFAULT_PADDING,
threshold=DEFAULT_THRESHOLD,
line_width=DEFAULT_LINE_WIDTH,
edge_items=DEFAULT_EDGE_ITEMS,
sign=DEFAULT_SIGN,
float_format=DEFAULT_FLOAT_FORMAT,
complex_format=DEFAULT_COMPLEX_FORMAT,
nan_string=DEFAULT_NAN_STRING,
inf_string=DEFAULT_INF_STRING,
formatted_width=DEFAULT_FORMATTED_WIDTH,
exponent_threshold=DEFAULT_EXPONENT_THRESHOLD,
suppress_scientific=DEFAULT_SUPPRESS_SCIENTIFIC,
)

alias printoptions = PrintOptions
alias GLOBAL_PRINT_OPTIONS = PrintOptions()


@value
Expand Down Expand Up @@ -146,7 +128,8 @@ struct PrintOptions:
self.suppress_scientific = suppress_scientific

fn __enter__(mut self) -> Self:
GLOBAL_PRINT_OPTIONS.set_options(
var default_print_options = PrintOptions()
default_print_options.set_options(
precision=self.precision,
suppress_small=self.suppress_small,
separator=self.separator,
Expand All @@ -163,10 +146,11 @@ struct PrintOptions:
exponent_threshold=self.exponent_threshold,
suppress_scientific=self.suppress_scientific,
)
return GLOBAL_PRINT_OPTIONS
return default_print_options

fn __exit__(mut self):
GLOBAL_PRINT_OPTIONS.set_options(
var default_print_options = PrintOptions()
default_print_options.set_options(
precision=DEFAULT_PRECISION,
suppress_small=DEFAULT_SUPPRESS_SMALL,
separator=DEFAULT_SEPARATOR,
Expand All @@ -192,7 +176,8 @@ fn set_printoptions(
padding: String = DEFAULT_PADDING,
edge_items: Int = DEFAULT_EDGE_ITEMS,
):
GLOBAL_PRINT_OPTIONS.set_options(
var default_print_options = PrintOptions()
default_print_options.set_options(
precision,
suppress_small,
separator,
Expand All @@ -204,7 +189,14 @@ fn set_printoptions(
# FIXME: fix the problem where precision > number of digits in the mantissa results in a not so exact value.
fn format_floating_scientific[
dtype: DType = DType.float64
](x: Scalar[dtype], precision: Int = 10, sign: Bool = False) raises -> String:
](
x: Scalar[dtype],
precision: Int = 10,
sign: Bool = False,
suppress_scientific: Bool = False,
exponent_threshold: Int = 4,
formatted_width: Int = 8,
) raises -> String:
"""
Format a float in scientific notation.

Expand All @@ -218,6 +210,11 @@ fn format_floating_scientific[
x: The float to format.
precision: The number of decimal places to include in the mantissa.
sign: Whether to include the sign of the float in the result. Defaults to False.
suppress_scientific: Whether to suppress scientific notation for small numbers.
Defaults to False.
exponent_threshold: The threshold for suppressing scientific notation.
Defaults to 4.
formatted_width: The width of the formatted string. Defaults to 8.

Returns:
A string representation of the float in scientific notation.
Expand All @@ -235,10 +232,6 @@ fn format_floating_scientific[
raise Error("Precision must be a non-negative integer.")

try:
var suppress_scientific = GLOBAL_PRINT_OPTIONS.suppress_scientific
var exponent_threshold = GLOBAL_PRINT_OPTIONS.exponent_threshold
var formatted_width = GLOBAL_PRINT_OPTIONS.formatted_width

if x == 0:
if sign:
var result: String = "+0." + "0" * precision + "e+00"
Expand Down Expand Up @@ -290,7 +283,12 @@ fn format_floating_scientific[

fn format_floating_precision[
dtype: DType
](value: Scalar[dtype], precision: Int, sign: Bool = False) raises -> String:
](
value: Scalar[dtype],
precision: Int,
sign: Bool = False,
suppress_small: Bool = False,
) raises -> String:
"""
Format a floating-point value to the specified precision.

Expand All @@ -299,6 +297,7 @@ fn format_floating_precision[
precision: The number of decimal places to include.
sign: Whether to include the sign of the float in the result.
Defaults to False.
suppress_small: Whether to suppress small numbers. Defaults to False.

Returns:
The formatted value as a string.
Expand All @@ -316,7 +315,6 @@ fn format_floating_precision[
if precision < 0:
raise Error("Precision must be a non-negative integer.")

var suppress_small = GLOBAL_PRINT_OPTIONS.suppress_small
if suppress_small and abs(value) < 1e-10:
var result: String = String("0.")
for _ in range(precision):
Expand Down Expand Up @@ -349,12 +347,19 @@ fn format_floating_precision[

fn format_floating_precision[
cdtype: CDType, dtype: DType
](value: ComplexSIMD[cdtype, dtype=dtype]) raises -> String:
](
value: ComplexSIMD[cdtype, dtype=dtype],
precision: Int = 4,
sign: Bool = False,
) raises -> String:
"""
Format a complex floating-point value to the specified precision.

Args:
value: The complex value to format.
precision: The number of decimal places to include.
sign: Whether to include the sign of the float in the result.
Defaults to False.

Returns:
The formatted value as a string.
Expand All @@ -363,9 +368,6 @@ fn format_floating_precision[
Error: If the complex value cannot be formatted.
"""
try:
var precision = GLOBAL_PRINT_OPTIONS.precision
var sign = GLOBAL_PRINT_OPTIONS.sign

return (
"("
+ format_floating_precision(value.re, precision, sign)
Expand Down
Loading