Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions numojo/__init__.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ from numojo.routines.io import (
from numojo.routines.io import printoptions, set_printoptions

from numojo.routines import linalg
from numojo.routines.linalg.misc import diagonal

from numojo.routines import logic
from numojo.routines.logic import (
Expand Down
2 changes: 1 addition & 1 deletion numojo/core/complex/complex_ndarray.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ struct ComplexNDArray[

Example:
```mojo
import numojo as nm
from numojo.prelude import *
var A = nm.ComplexNDArray[cf32](Shape(2,3,4))
```
"""
Expand Down
97 changes: 52 additions & 45 deletions numojo/core/ndarray.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -69,17 +69,9 @@ from numojo.routines.io.formatting import (
PrintOptions,
GLOBAL_PRINT_OPTIONS,
)
import numojo.routines.linalg as linalg
from numojo.routines.linalg.products import matmul
import numojo.routines.logic.comparison as comparison
from numojo.routines.manipulation import reshape, ravel
import numojo.routines.math.arithmetic as arithmetic
import numojo.routines.math.extrema as extrema
from numojo.routines.math.products import prod, cumprod
import numojo.routines.math.rounding as rounding
from numojo.routines.math.sums import sum, cumsum
import numojo.routines.sorting as sorting
from numojo.routines.statistics.averages import mean


struct NDArray[dtype: DType = DType.float64](
Expand Down Expand Up @@ -2399,7 +2391,7 @@ struct NDArray[dtype: DType = DType.float64](
self = self - other

fn __matmul__(self, other: Self) raises -> Self:
return matmul(self, other)
return numojo.linalg.matmul(self, other)

fn __mul__[
OtherDType: DType,
Expand Down Expand Up @@ -3221,7 +3213,7 @@ struct NDArray[dtype: DType = DType.float64](
The indices of the sorted NDArray.
"""

return sorting.argsort(self)
return numojo.sorting.argsort(self)

fn argsort(self, axis: Int) raises -> NDArray[DType.index]:
"""
Expand All @@ -3232,7 +3224,7 @@ struct NDArray[dtype: DType = DType.float64](
The indices of the sorted NDArray.
"""

return sorting.argsort(self, axis=axis)
return numojo.sorting.argsort(self, axis=axis)

fn astype[target: DType](self) raises -> NDArray[target]:
"""
Expand Down Expand Up @@ -3322,7 +3314,7 @@ struct NDArray[dtype: DType = DType.float64](
Returns:
Cumprod of all items of an array.
"""
return cumprod[dtype](self)
return numojo.math.cumprod[dtype](self)

fn cumprod(self, axis: Int) raises -> NDArray[dtype]:
"""
Expand All @@ -3334,7 +3326,7 @@ struct NDArray[dtype: DType = DType.float64](
Returns:
Cumprod of array by axis.
"""
return cumprod[dtype](self, axis=axis)
return numojo.math.cumprod[dtype](self, axis=axis)

fn cumsum(self) raises -> NDArray[dtype]:
"""
Expand All @@ -3344,7 +3336,7 @@ struct NDArray[dtype: DType = DType.float64](
Returns:
Cumsum of all items of an array.
"""
return cumsum[dtype](self)
return numojo.math.cumsum[dtype](self)

fn cumsum(self, axis: Int) raises -> NDArray[dtype]:
"""
Expand All @@ -3356,10 +3348,27 @@ struct NDArray[dtype: DType = DType.float64](
Returns:
Cumsum of array by axis.
"""
return cumsum[dtype](self, axis=axis)
return numojo.math.cumsum[dtype](self, axis=axis)

fn diagonal(self):
pass
fn diagonal[dtype: DType](self, offset: Int = 0) raises -> Self:
"""
Returns specific diagonals.
Currently supports only 2D arrays.

Raises:
Error: If the array is not 2D.
Error: If the offset is beyond the shape of the array.

Parameters:
dtype: Data type of the array.

Args:
offset: Offset of the diagonal from the main diagonal.

Returns:
The diagonal of the NDArray.
"""
return numojo.linalg.diagonal(self, offset=offset)

fn fill(mut self, val: Scalar[dtype]):
"""
Expand Down Expand Up @@ -3495,7 +3504,7 @@ struct NDArray[dtype: DType = DType.float64](
The max value.
"""

return extrema.max(self)
return numojo.math.max(self)

fn max(self, axis: Int) raises -> Self:
"""
Expand All @@ -3510,7 +3519,7 @@ struct NDArray[dtype: DType = DType.float64](
An array with reduced number of dimensions.
"""

return extrema.max(self, axis=axis)
return numojo.math.max(self, axis=axis)

# TODO: Remove this methods
fn mdot(self, other: Self) raises -> Self:
Expand Down Expand Up @@ -3563,7 +3572,7 @@ struct NDArray[dtype: DType = DType.float64](
Returns:
The mean of the array.
"""
return mean[returned_dtype](self)
return numojo.statistics.mean[returned_dtype](self)

fn mean[
returned_dtype: DType = DType.float64
Expand All @@ -3578,7 +3587,7 @@ struct NDArray[dtype: DType = DType.float64](
An NDArray.

"""
return mean[returned_dtype](self, axis)
return numojo.statistics.mean[returned_dtype](self, axis)

fn median[
returned_dtype: DType = DType.float64
Expand Down Expand Up @@ -3615,7 +3624,7 @@ struct NDArray[dtype: DType = DType.float64](
The min value.
"""

return extrema.min(self)
return numojo.math.min(self)

fn min(self, axis: Int) raises -> Self:
"""
Expand All @@ -3630,7 +3639,7 @@ struct NDArray[dtype: DType = DType.float64](
An array with reduced number of dimensions.
"""

return extrema.min(self, axis=axis)
return numojo.math.min(self, axis=axis)

fn nditer(self) raises -> _NDIter[__origin_of(self), dtype]:
"""
Expand Down Expand Up @@ -3705,7 +3714,7 @@ struct NDArray[dtype: DType = DType.float64](
Returns:
Scalar.
"""
return sum(self)
return numojo.math.prod(self)

fn prod(self: Self, axis: Int) raises -> Self:
"""
Expand All @@ -3718,7 +3727,7 @@ struct NDArray[dtype: DType = DType.float64](
An NDArray.
"""

return prod(self, axis=axis)
return numojo.math.prod(self, axis=axis)

# TODO: Remove this methods
fn rdot(self, other: Self) raises -> Self:
Expand Down Expand Up @@ -3770,7 +3779,7 @@ struct NDArray[dtype: DType = DType.float64](
Returns:
Array of the same data with a new shape.
"""
return reshape[dtype](self, shape=shape, order=order)
return numojo.reshape(self, shape=shape, order=order)

fn resize(mut self, shape: NDArrayShape) raises:
"""
Expand Down Expand Up @@ -3829,30 +3838,28 @@ struct NDArray[dtype: DType = DType.float64](
buffer.store(i, self._buf.ptr.load[width=1](i + id * width))
return buffer

fn sort(mut self) raises:
fn sort(mut self, axis: Int = -1) raises:
"""
Sort NDArray using quick sort method.
It is not guaranteed to be unstable.

When no axis is given, the array is flattened before sorting.

Sorts the array in-place along the given axis using quick sort method.
The deault axis is -1.
See `numojo.sorting.sort` for more information.
"""
var I = NDArray[DType.index](self.shape)
self = ravel(self)
sorting._sort_inplace(self, I, axis=0)

fn sort(mut self, owned axis: Int) raises:
Args:
axis: The axis along which the array is sorted. Defaults to -1.
"""
Sort NDArray along the given axis using quick sort method.
It is not guaranteed to be unstable.

When no axis is given, the array is flattened before sorting.
var normalized_axis: Int = axis
if normalized_axis < 0:
normalized_axis += self.ndim
if (normalized_axis >= self.ndim) or (normalized_axis < 0):
raise Error(
String(
"\nError in `NDArray.sort()`: "
"Axis ({}) is not in valid range [-{}, {})."
).format(axis, self.ndim, self.ndim)
)

See `numojo.sorting.sort` for more information.
"""
var I = NDArray[DType.index](self.shape)
sorting._sort_inplace(self, I, axis=axis)
numojo.sorting._sort_inplace(self, axis=normalized_axis)

fn std[
returned_dtype: DType = DType.float64
Expand Down Expand Up @@ -4002,7 +4009,7 @@ struct NDArray[dtype: DType = DType.float64](
Returns:
The trace of the ndarray.
"""
return linalg.norms.trace[dtype](self, offset, axis1, axis2)
return numojo.linalg.trace[dtype](self, offset, axis1, axis2)

# TODO: Remove the underscore in the method name when view is supported.
fn _transpose(self) raises -> Self:
Expand Down
1 change: 1 addition & 0 deletions numojo/routines/linalg/__init__.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ from .decompositions import lu_decomposition, qr
from .norms import det, trace
from .products import cross, dot, matmul
from .solving import inv, solve, lstsq
from .misc import diagonal
61 changes: 61 additions & 0 deletions numojo/routines/linalg/misc.mojo
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# ===----------------------------------------------------------------------=== #
# Distributed under the Apache 2.0 License with LLVM Exceptions.
# See LICENSE and the LLVM License for more information.
# https://github.com/Mojo-Numerics-and-Algorithms-group/NuMojo/blob/main/LICENSE
# https://llvm.org/LICENSE.txt
# ===----------------------------------------------------------------------=== #

# ===----------------------------------------------------------------------=== #
# Miscellaneous Linear Algebra Routines
# ===----------------------------------------------------------------------=== #

from numojo.core.ndarray import NDArray


fn diagonal[
dtype: DType
](a: NDArray[dtype], offset: Int = 0) raises -> NDArray[dtype]:
"""
Returns specific diagonals.
Currently supports only 2D arrays.

Raises:
Error: If the array is not 2D.
Error: If the offset is beyond the shape of the array.

Parameters:
dtype: Data type of the array.

Args:
a: An NDArray.
offset: Offset of the diagonal from the main diagonal.

Returns:
The diagonal of the NDArray.
"""

if a.ndim != 2:
raise Error("\nError in `diagonal`: Only supports 2D arrays")

var m = a.shape[0]
var n = a.shape[1]

if offset >= max(m, n): # Offset beyond the shape of the array
raise Error(
"\nError in `diagonal`: Offset beyond the shape of the array"
)

var res: NDArray[dtype]

if offset >= 0:
var size_of_res = min(n - offset, m)
res = NDArray[dtype](Shape(size_of_res))
for i in range(size_of_res):
res.item(i) = a.item(i, i + offset)
else:
var size_of_res = min(m + offset, m)
res = NDArray[dtype](Shape(size_of_res))
for i in range(size_of_res):
res.item(i) = a.item(i - offset, i)

return res
45 changes: 45 additions & 0 deletions numojo/routines/sorting.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -673,6 +673,51 @@ fn _sort_inplace[
)


fn _sort_inplace[dtype: DType](mut A: NDArray[dtype], axis: Int) raises:
"""
Sort in-place NDArray along the given axis using quick sort method.
It is not guaranteed to be unstable.

Parameters:
dtype: The input element type.

Args:
A: NDArray to sort.
axis: The axis along which the array is sorted.
"""

if (axis >= A.ndim) or (axis < 0):
raise Error(
String(
"\nError in `_sort_inplace()`: "
"Axis ({}) is not in valid range [0, {})."
).format(axis, A.ndim)
)

var array_order = "C" if A.flags.C_CONTIGUOUS else "F"
var continous_axis = A.ndim - 1 if array_order == "C" else A.ndim - 2
"""Contiguously stored axis. -1 if row-major, -2 if col-major."""

if axis == continous_axis: # Last axis
for i in range(A.size // A.shape[continous_axis]):
_sort_in_range(
A,
left=i * A.shape[continous_axis],
right=(i + 1) * A.shape[continous_axis] - 1,
)
else:
var transposed_axes = List[Int](capacity=A.ndim)
for i in range(A.ndim):
transposed_axes.append(i)
transposed_axes[axis], transposed_axes[continous_axis] = (
transposed_axes[continous_axis],
transposed_axes[axis],
)
A = transpose(A, axes=transposed_axes)
_sort_inplace(A, axis=A.ndim - 1)
A = transpose(A, axes=transposed_axes)


fn _sort_inplace[
dtype: DType
](mut A: NDArray[dtype], mut I: NDArray[DType.index], owned axis: Int) raises:
Expand Down
12 changes: 12 additions & 0 deletions tests/routines/test_linalg.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -105,3 +105,15 @@ def norms():
check_values_close(
nm.math.linalg.det(arr), np.linalg.det(np_arr), "`det` is broken"
)


def test_misc():
var np = Python.import_module("numpy")
var arr = nm.core.random.rand(4, 8)
var np_arr = arr.to_numpy()
for i in range(-3, 8):
check_is_close(
nm.diagonal(arr, offset=i),
np.diagonal(np_arr, offset=i),
String("`diagonal` by axis {} is broken").format(i),
)
Loading