Skip to content

Commit f0dea37

Browse files
[fix] Hotfix the bug that numojo crashes on mojopkg (#227) Numojo 0.6.1
This pull request includes several changes to the `mojoproject.toml` file and the `numojo` package, focusing on improving the build process, testing, and formatting functionalities. Additionally, there are updates to the handling of print options in the `numojo` package. ### Improvements to build and testing process: * [`mojoproject.toml`](diffhunk://#diff-1b0ef62120bccf4c05a76e60e13fe686f33eb92d5897b3f05e0c3f8d737fc5c0L19-R27): Updated the `package` task to include copying the package to the tests folder and modified the `test` task to test the built package. * [`mojoproject.toml`](diffhunk://#diff-1b0ef62120bccf4c05a76e60e13fe686f33eb92d5897b3f05e0c3f8d737fc5c0L43-R44): Simplified the `final` task to run format and test commands, and updated the `f` alias to run the `final` task. ### Dependency updates: * [`mojoproject.toml`](diffhunk://#diff-1b0ef62120bccf4c05a76e60e13fe686f33eb92d5897b3f05e0c3f8d737fc5c0L53-R53): Updated the `max` dependency version to `25.1.1`. ### Codebase simplification: * `numojo/__init__.mojo`, `numojo/core/complex/complex_ndarray.mojo`, `numojo/routines/io/__init__.mojo`: Removed the unused `printoptions` import. [[1]](diffhunk://#diff-c866aa2a9b7b267a54550b85f8f3bd958d31806072b3f8449cfeb3e1b04fb954L65-R65) [[2]](diffhunk://#diff-50999070ae6817e2537de378c19bc32ee0ee691681f0c39492d8e9fc0d540a35L37) [[3]](diffhunk://#diff-44cb945c42164cf66f26636ae61d288b6b52ed68e756a3e09b81725c27a97b96L9-L10) * [`numojo/routines/io/formatting.mojo`](diffhunk://#diff-ebe0524c71023004d26ae3e618bbaabd5177b5261be63254bf3d483f16af991aL23-R23): Replaced `GLOBAL_PRINT_OPTIONS` with a local `default_print_options` instance in several functions and methods, and added new parameters to the `format_floating_scientific` and `format_floating_precision` functions for better control over formatting behavior. [[1]](diffhunk://#diff-ebe0524c71023004d26ae3e618bbaabd5177b5261be63254bf3d483f16af991aL23-R23) [[2]](diffhunk://#diff-ebe0524c71023004d26ae3e618bbaabd5177b5261be63254bf3d483f16af991aL149-R132) [[3]](diffhunk://#diff-ebe0524c71023004d26ae3e618bbaabd5177b5261be63254bf3d483f16af991aL166-R153) [[4]](diffhunk://#diff-ebe0524c71023004d26ae3e618bbaabd5177b5261be63254bf3d483f16af991aL195-R180) [[5]](diffhunk://#diff-ebe0524c71023004d26ae3e618bbaabd5177b5261be63254bf3d483f16af991aL207-R199) [[6]](diffhunk://#diff-ebe0524c71023004d26ae3e618bbaabd5177b5261be63254bf3d483f16af991aR213-R217) [[7]](diffhunk://#diff-ebe0524c71023004d26ae3e618bbaabd5177b5261be63254bf3d483f16af991aL238-L241) [[8]](diffhunk://#diff-ebe0524c71023004d26ae3e618bbaabd5177b5261be63254bf3d483f16af991aL293-R291) [[9]](diffhunk://#diff-ebe0524c71023004d26ae3e618bbaabd5177b5261be63254bf3d483f16af991aR300) [[10]](diffhunk://#diff-ebe0524c71023004d26ae3e618bbaabd5177b5261be63254bf3d483f16af991aL319) [[11]](diffhunk://#diff-ebe0524c71023004d26ae3e618bbaabd5177b5261be63254bf3d483f16af991aL352-R362) [[12]](diffhunk://#diff-ebe0524c71023004d26ae3e618bbaabd5177b5261be63254bf3d483f16af991aL366-L368) --------- Co-authored-by: MadAlex1997 <[email protected]>
1 parent 23f316c commit f0dea37

File tree

8 files changed

+403
-48
lines changed

8 files changed

+403
-48
lines changed

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ magic.lock
2727

2828
# Miscellaneous files
2929
mojo
30-
/numojo.mojopkg
30+
numojo.mojopkg
3131
/bench.mojo
3232
/test*.mojo
3333
/test*.ipynb

mojoproject.toml

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,15 @@ license = "Apache-2.0"
1616
readme = "README.MD"
1717

1818
[tasks]
19-
# compile the package
20-
package = "magic run mojo package numojo"
19+
# compile the package and copy it to the tests folder
20+
package = "magic run mojo package numojo && cp numojo.mojopkg tests/"
2121
p = "clear && magic run package"
2222

2323
# format the package
24-
format = "magic run mojo format ./ && magic run mojo format docs/readthedocs/docs.py"
24+
format = "magic run mojo format ./"
2525

26-
# test whether tests pass and the package can be built
27-
test = "magic run mojo test tests -I ./ -I ./tests/ && magic run package"
26+
# test whether tests pass on the built package
27+
test = "magic run package && magic run mojo test tests -I ./tests/"
2828
t = "clear && magic run test"
2929

3030
# run individual tests to avoid overheat
@@ -40,8 +40,8 @@ test_statistics = "magic run mojo test tests/routines/test_statistics.mojo -I ./
4040
test_sorting = "magic run mojo test tests/routines/test_sorting.mojo -I ./ -I ./tests/"
4141

4242
# run all final checks before a commit
43-
final = "magic run test && magic run format && magic run package"
44-
f = "clear && magic run test && magic run format && magic run package"
43+
final = "magic run format && magic run test"
44+
f = "clear && magic run final"
4545

4646
# Automatically Generate doc pages
4747
doc_pages = "mojo doc numojo/ -o docs.json"
@@ -50,7 +50,7 @@ doc_pages = "mojo doc numojo/ -o docs.json"
5050
release = "clear && magic run final && magic run doc_pages"
5151

5252
[dependencies]
53-
max = "=25.1"
53+
max = "=25.1.1"
5454
python = ">=3.11"
5555
numpy = ">=1.19"
5656
scipy = ">=1.14"

numojo/__init__.mojo

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ from numojo.routines.io import (
6262
loadtxt,
6363
savetxt,
6464
)
65-
from numojo.routines.io import printoptions, set_printoptions
65+
from numojo.routines.io import set_printoptions
6666

6767
from numojo.routines import linalg
6868
from numojo.routines.linalg.misc import diagonal
@@ -181,7 +181,7 @@ from numojo.routines.creation import (
181181
)
182182

183183
from numojo.routines import indexing
184-
from numojo.routines.indexing import where, compress
184+
from numojo.routines.indexing import where, compress, take_along_axis
185185

186186
from numojo.routines.functional import apply_along_axis
187187

numojo/core/complex/complex_ndarray.mojo

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@ from numojo.routines.io.formatting import (
3434
format_floating_scientific,
3535
format_value,
3636
PrintOptions,
37-
printoptions,
3837
GLOBAL_PRINT_OPTIONS,
3938
)
4039
import numojo.routines.linalg as linalg

numojo/routines/indexing.mojo

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,3 +232,126 @@ fn compress[
232232

233233
else:
234234
return compress(condition, ravel(a), axis=0)
235+
236+
237+
fn take_along_axis[
238+
dtype: DType, //,
239+
](
240+
arr: NDArray[dtype], indices: NDArray[DType.index], axis: Int = 0
241+
) raises -> NDArray[dtype]:
242+
"""
243+
Takes values from the input array along the given axis based on indices.
244+
245+
Raises:
246+
Error: If the axis is out of bounds for the given array.
247+
Error: If the ndim of arr and indices are not the same.
248+
Error: If the shape of indices does not match the shape of the
249+
input array except along the given axis.
250+
251+
Parameters:
252+
dtype: DType of the input array.
253+
254+
Args:
255+
arr: The source array.
256+
indices: The indices array.
257+
axis: The axis along which to take values. Default is 0.
258+
259+
Returns:
260+
An array with the same shape as indices with values taken from the
261+
input array along the given axis.
262+
263+
Examples:
264+
265+
```console
266+
> var a = nm.arange[i8](12).reshape(Shape(3, 4))
267+
> print(a)
268+
[[ 0 1 2 3]
269+
[ 4 5 6 7]
270+
[ 8 9 10 11]]
271+
> ind = nm.array[intp]("[[0, 1, 2, 0], [1, 0, 2, 1]]")
272+
> print(ind)
273+
[[0 1 2 0]
274+
[1 0 2 1]]
275+
> print(nm.indexing.take_along_axis(a, ind, axis=0))
276+
[[ 0 5 10 3]
277+
[ 4 1 10 7]]
278+
```
279+
.
280+
"""
281+
var normalized_axis = axis
282+
if normalized_axis < 0:
283+
normalized_axis = arr.ndim + normalized_axis
284+
if (normalized_axis >= arr.ndim) or (normalized_axis < 0):
285+
raise Error(
286+
String(
287+
"\nError in `take_along_axis`: Axis {} is out of bound for"
288+
" array with {} dimensions"
289+
).format(axis, arr.ndim)
290+
)
291+
292+
# Check if the ndim of arr and indices are same
293+
if arr.ndim != indices.ndim:
294+
raise Error(
295+
String(
296+
"\nError in `take_along_axis`: The ndim of arr and indices must"
297+
" be same. Got {} and {}.".format(arr.ndim, indices.ndim)
298+
)
299+
)
300+
301+
# broadcast indices to the shape of arr if necessary
302+
# When broadcasting, the shape of indices must match the shape of arr
303+
# except along the axis
304+
305+
var broadcasted_indices = indices
306+
307+
if arr.shape != indices.shape:
308+
var arr_shape_new = arr.shape
309+
arr_shape_new[normalized_axis] = indices.shape[normalized_axis]
310+
311+
try:
312+
broadcasted_indices = numojo.broadcast_to(indices, arr_shape_new)
313+
except e:
314+
raise Error(
315+
String(
316+
"\nError in `take_along_axis`: Shape of indices must match"
317+
" shape of array except along the given axis. "
318+
+ String(e)
319+
)
320+
)
321+
322+
# Create output array with same shape as broadcasted_indices
323+
var result = NDArray[dtype](Shape(broadcasted_indices.shape))
324+
325+
var arr_iterator = arr.iter_along_axis(normalized_axis)
326+
var indices_iterator = broadcasted_indices.iter_along_axis(normalized_axis)
327+
var length_of_iterator = result.size // result.shape[normalized_axis]
328+
329+
if normalized_axis == arr.ndim - 1:
330+
# If axis is the last axis, the data is contiguous.
331+
for i in range(length_of_iterator):
332+
var arr_slice = arr_iterator.ith(i)
333+
var indices_slice = indices_iterator.ith(i)
334+
var arr_slice_after_applying_indices = arr_slice[indices_slice]
335+
memcpy(
336+
result._buf.ptr + i * result.shape[normalized_axis],
337+
arr_slice_after_applying_indices._buf.ptr,
338+
result.shape[normalized_axis],
339+
)
340+
else:
341+
# If axis is not the last axis, the data is not contiguous.
342+
for i in range(length_of_iterator):
343+
var indices_slice_offsets: NDArray[DType.index]
344+
var indices_slice: NDArray[DType.index]
345+
indices_slice_offsets, indices_slice = (
346+
indices_iterator.ith_with_offsets(i)
347+
)
348+
var arr_slice = arr_iterator.ith(i)
349+
var arr_slice_after_applying_indices = arr_slice[indices_slice]
350+
for j in range(arr_slice_after_applying_indices.size):
351+
(
352+
result._buf.ptr + Int(indices_slice_offsets[j])
353+
).init_pointee_copy(
354+
arr_slice_after_applying_indices._buf.ptr[j]
355+
)
356+
357+
return result

numojo/routines/io/__init__.mojo

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,5 @@ from .files import (
66
from .formatting import (
77
format_floating_scientific,
88
PrintOptions,
9-
printoptions,
10-
GLOBAL_PRINT_OPTIONS,
119
set_printoptions,
1210
)

numojo/routines/io/formatting.mojo

Lines changed: 36 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -20,25 +20,7 @@ alias DEFAULT_FORMATTED_WIDTH = 8
2020
alias DEFAULT_EXPONENT_THRESHOLD = 4
2121
alias DEFAULT_SUPPRESS_SCIENTIFIC = False
2222

23-
var GLOBAL_PRINT_OPTIONS: PrintOptions = PrintOptions(
24-
precision=DEFAULT_PRECISION,
25-
suppress_small=DEFAULT_SUPPRESS_SMALL,
26-
separator=DEFAULT_SEPARATOR,
27-
padding=DEFAULT_PADDING,
28-
threshold=DEFAULT_THRESHOLD,
29-
line_width=DEFAULT_LINE_WIDTH,
30-
edge_items=DEFAULT_EDGE_ITEMS,
31-
sign=DEFAULT_SIGN,
32-
float_format=DEFAULT_FLOAT_FORMAT,
33-
complex_format=DEFAULT_COMPLEX_FORMAT,
34-
nan_string=DEFAULT_NAN_STRING,
35-
inf_string=DEFAULT_INF_STRING,
36-
formatted_width=DEFAULT_FORMATTED_WIDTH,
37-
exponent_threshold=DEFAULT_EXPONENT_THRESHOLD,
38-
suppress_scientific=DEFAULT_SUPPRESS_SCIENTIFIC,
39-
)
40-
41-
alias printoptions = PrintOptions
23+
alias GLOBAL_PRINT_OPTIONS = PrintOptions()
4224

4325

4426
@value
@@ -146,7 +128,8 @@ struct PrintOptions:
146128
self.suppress_scientific = suppress_scientific
147129

148130
fn __enter__(mut self) -> Self:
149-
GLOBAL_PRINT_OPTIONS.set_options(
131+
var default_print_options = PrintOptions()
132+
default_print_options.set_options(
150133
precision=self.precision,
151134
suppress_small=self.suppress_small,
152135
separator=self.separator,
@@ -163,10 +146,11 @@ struct PrintOptions:
163146
exponent_threshold=self.exponent_threshold,
164147
suppress_scientific=self.suppress_scientific,
165148
)
166-
return GLOBAL_PRINT_OPTIONS
149+
return default_print_options
167150

168151
fn __exit__(mut self):
169-
GLOBAL_PRINT_OPTIONS.set_options(
152+
var default_print_options = PrintOptions()
153+
default_print_options.set_options(
170154
precision=DEFAULT_PRECISION,
171155
suppress_small=DEFAULT_SUPPRESS_SMALL,
172156
separator=DEFAULT_SEPARATOR,
@@ -192,7 +176,8 @@ fn set_printoptions(
192176
padding: String = DEFAULT_PADDING,
193177
edge_items: Int = DEFAULT_EDGE_ITEMS,
194178
):
195-
GLOBAL_PRINT_OPTIONS.set_options(
179+
var default_print_options = PrintOptions()
180+
default_print_options.set_options(
196181
precision,
197182
suppress_small,
198183
separator,
@@ -204,7 +189,14 @@ fn set_printoptions(
204189
# FIXME: fix the problem where precision > number of digits in the mantissa results in a not so exact value.
205190
fn format_floating_scientific[
206191
dtype: DType = DType.float64
207-
](x: Scalar[dtype], precision: Int = 10, sign: Bool = False) raises -> String:
192+
](
193+
x: Scalar[dtype],
194+
precision: Int = 10,
195+
sign: Bool = False,
196+
suppress_scientific: Bool = False,
197+
exponent_threshold: Int = 4,
198+
formatted_width: Int = 8,
199+
) raises -> String:
208200
"""
209201
Format a float in scientific notation.
210202
@@ -218,6 +210,11 @@ fn format_floating_scientific[
218210
x: The float to format.
219211
precision: The number of decimal places to include in the mantissa.
220212
sign: Whether to include the sign of the float in the result. Defaults to False.
213+
suppress_scientific: Whether to suppress scientific notation for small numbers.
214+
Defaults to False.
215+
exponent_threshold: The threshold for suppressing scientific notation.
216+
Defaults to 4.
217+
formatted_width: The width of the formatted string. Defaults to 8.
221218
222219
Returns:
223220
A string representation of the float in scientific notation.
@@ -235,10 +232,6 @@ fn format_floating_scientific[
235232
raise Error("Precision must be a non-negative integer.")
236233

237234
try:
238-
var suppress_scientific = GLOBAL_PRINT_OPTIONS.suppress_scientific
239-
var exponent_threshold = GLOBAL_PRINT_OPTIONS.exponent_threshold
240-
var formatted_width = GLOBAL_PRINT_OPTIONS.formatted_width
241-
242235
if x == 0:
243236
if sign:
244237
var result: String = "+0." + "0" * precision + "e+00"
@@ -290,7 +283,12 @@ fn format_floating_scientific[
290283

291284
fn format_floating_precision[
292285
dtype: DType
293-
](value: Scalar[dtype], precision: Int, sign: Bool = False) raises -> String:
286+
](
287+
value: Scalar[dtype],
288+
precision: Int,
289+
sign: Bool = False,
290+
suppress_small: Bool = False,
291+
) raises -> String:
294292
"""
295293
Format a floating-point value to the specified precision.
296294
@@ -299,6 +297,7 @@ fn format_floating_precision[
299297
precision: The number of decimal places to include.
300298
sign: Whether to include the sign of the float in the result.
301299
Defaults to False.
300+
suppress_small: Whether to suppress small numbers. Defaults to False.
302301
303302
Returns:
304303
The formatted value as a string.
@@ -316,7 +315,6 @@ fn format_floating_precision[
316315
if precision < 0:
317316
raise Error("Precision must be a non-negative integer.")
318317

319-
var suppress_small = GLOBAL_PRINT_OPTIONS.suppress_small
320318
if suppress_small and abs(value) < 1e-10:
321319
var result: String = String("0.")
322320
for _ in range(precision):
@@ -349,12 +347,19 @@ fn format_floating_precision[
349347

350348
fn format_floating_precision[
351349
cdtype: CDType, dtype: DType
352-
](value: ComplexSIMD[cdtype, dtype=dtype]) raises -> String:
350+
](
351+
value: ComplexSIMD[cdtype, dtype=dtype],
352+
precision: Int = 4,
353+
sign: Bool = False,
354+
) raises -> String:
353355
"""
354356
Format a complex floating-point value to the specified precision.
355357
356358
Args:
357359
value: The complex value to format.
360+
precision: The number of decimal places to include.
361+
sign: Whether to include the sign of the float in the result.
362+
Defaults to False.
358363
359364
Returns:
360365
The formatted value as a string.
@@ -363,9 +368,6 @@ fn format_floating_precision[
363368
Error: If the complex value cannot be formatted.
364369
"""
365370
try:
366-
var precision = GLOBAL_PRINT_OPTIONS.precision
367-
var sign = GLOBAL_PRINT_OPTIONS.sign
368-
369371
return (
370372
"("
371373
+ format_floating_precision(value.re, precision, sign)

0 commit comments

Comments
 (0)