Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions numojo/core/_math_funcs.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ from memory import UnsafePointer

from numojo.core.traits.backend import Backend
from numojo.core.ndarray import NDArray
from numojo.routines.creation import _0darray

# TODO Add string method to give name

Expand Down Expand Up @@ -143,6 +144,13 @@ struct Vectorized(Backend):
Returns:
A a new NDArray that is NDArray with the function func applied.
"""

# For 0darray (numojo scalar)
# Treat it as a scalar and apply the function
if array.ndim == 0:
var result_array = _0darray(val=func[dtype, 1](array._buf.ptr[]))
return result_array

var result_array: NDArray[dtype] = NDArray[dtype](array.shape)
alias width = simdwidthof[dtype]()

Expand Down Expand Up @@ -186,6 +194,13 @@ struct Vectorized(Backend):
"Shape Mismatch error shapes must match for this function"
)

# For 0darray (numojo scalar)
# Treat it as a scalar and apply the function
if array2.ndim == 0:
return self.math_func_1_array_1_scalar_in_one_array_out[
dtype, func
](array1, array2[])

var result_array: NDArray[dtype] = NDArray[dtype](array1.shape)
alias width = simdwidthof[dtype]()

Expand Down Expand Up @@ -223,6 +238,12 @@ struct Vectorized(Backend):
A a new NDArray that is NDArray with the function func applied.
"""

# For 0darray (numojo scalar)
# Treat it as a scalar and apply the function
if array.ndim == 0:
var result_array = _0darray(val=func[dtype, 1](array[], scalar))
return result_array

var result_array: NDArray[dtype] = NDArray[dtype](array.shape)
alias width = simdwidthof[dtype]()

Expand All @@ -249,6 +270,14 @@ struct Vectorized(Backend):
raise Error(
"Shape Mismatch error shapes must match for this function"
)

# For 0darray (numojo scalar)
# Treat it as a scalar and apply the function
if array2.ndim == 0:
return self.math_func_compare_array_and_scalar[dtype, func](
array1, array2[]
)

var result_array: NDArray[DType.bool] = NDArray[DType.bool](
array1.shape
)
Expand Down Expand Up @@ -279,6 +308,12 @@ struct Vectorized(Backend):
](
self: Self, array1: NDArray[dtype], scalar: SIMD[dtype, 1]
) raises -> NDArray[DType.bool]:
# For 0darray (numojo scalar)
# Treat it as a scalar and apply the function
if array1.ndim == 0:
var result_array = _0darray(val=func[dtype, 1](array1[], scalar))
return result_array

var result_array: NDArray[DType.bool] = NDArray[DType.bool](
array1.shape
)
Expand Down
4 changes: 4 additions & 0 deletions numojo/core/datatypes.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -235,5 +235,9 @@ fn _concise_dtype_str(dtype: DType) -> String:
return "f32"
elif dtype == f64:
return "f64"
elif dtype == boolean:
return "boolean"
elif dtype == isize:
return "isize"
else:
return "Unknown"
14 changes: 8 additions & 6 deletions numojo/core/item.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -133,9 +133,10 @@ struct Item(CollectionElement):

if normalized_idx < 0 or normalized_idx >= self.ndim:
raise Error(
String("Index ({}) out of range [{}, {})").format(
index(idx), -self.ndim, self.ndim - 1
)
String(
"Error in `numojo.Item.__getitem__()`: \n"
"Index ({}) out of range [{}, {})\n"
).format(index(idx), -self.ndim, self.ndim - 1)
)

return self._buf[normalized_idx]
Expand All @@ -159,9 +160,10 @@ struct Item(CollectionElement):

if normalized_idx < 0 or normalized_idx >= self.ndim:
raise Error(
String("Index ({}) out of range [{}, {})").format(
index(idx), -self.ndim, self.ndim - 1
)
String(
"Error in `numojo.Item.__getitem__()`: \n"
"Index ({}) out of range [{}, {})\n"
).format(index(idx), -self.ndim, self.ndim - 1)
)

self._buf[normalized_idx] = index(val)
Expand Down
Loading