Skip to content

Commit ba0249e

Browse files
[core][ComplexNDArray] Improve ComplexNDArray methods (#275)
## Pull Request Overview (From Copilot) This PR enhances ComplexNDArray functionality by adding comparison operators, trait methods, statistical/reduction methods, and array manipulation capabilities. It also introduces temporary Int conversions for strides/shape operations and implements SIMD load/store methods for vectorized calculations. ### Key Changes - Added trait implementations (ImplicitlyCopyable, Movable) and conversion methods (__bool__, __int__, __float__) for ComplexNDArray - Implemented magnitude-based comparison operators (__lt__, __le__, __gt__, __ge__) for complex arrays - Added statistical methods (all, any, sum, prod, mean, max, min, argmax, argmin, cumsum, cumprod) and array manipulation methods (flatten, fill, row, col, clip, round, T, diagonal, trace, tolist, resize) - Changed internal buffer types from `UnsafePointer[Int]` to `UnsafePointer[Scalar[DType.int]]` in NDArrayShape, NDArrayStrides, and Item structs - Added SIMD load/store methods (load, store, unsafe_load, unsafe_store) for Item, Shape, and Strides <details> <summary>Show a summary per file</summary> | File | Description | | ---- | ----------- | | numojo/routines/indexing.mojo | Added Int conversions for stride operations in compress function | | numojo/routines/creation.mojo | Removed duplicate import statements | | numojo/core/ndstrides.mojo | Changed buffer type to Scalar[DType.int], updated __setitem__ validation, added SIMD load/store methods | | numojo/core/ndshape.mojo | Changed buffer type to Scalar[DType.int], updated __setitem__ validation, added SIMD load/store methods, modified size_of_array calculation | | numojo/core/ndarray.mojo | Added Int conversions for stride/shape buffer accesses throughout | | numojo/core/item.mojo | Changed buffer type to Scalar[DType.int], removed Item.__init__(idx, shape) constructor and offset() method, added SIMD load/store methods | | numojo/core/complex/complex_simd.mojo | Added ImplicitlyCopyable and Movable traits to ComplexSIMD | | numojo/core/complex/complex_ndarray.mojo | Added comparison operators, conversion methods, power operations, statistical methods, and array manipulation methods; added Int conversions for stride operations | </details> --------- Co-authored-by: ZHU Yuhao 朱宇浩 <[email protected]>
1 parent 3876690 commit ba0249e

File tree

8 files changed

+2087
-380
lines changed

8 files changed

+2087
-380
lines changed

numojo/core/complex/complex_ndarray.mojo

Lines changed: 1188 additions & 4 deletions
Large diffs are not rendered by default.

numojo/core/complex/complex_simd.mojo

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,9 @@ alias CScalar[cdtype: ComplexDType] = ComplexSIMD[cdtype, width=1]
2727

2828

2929
@register_passable("trivial")
30-
struct ComplexSIMD[cdtype: ComplexDType, width: Int = 1](Stringable, Writable):
30+
struct ComplexSIMD[cdtype: ComplexDType, width: Int = 1](
31+
ImplicitlyCopyable, Movable, Stringable, Writable
32+
):
3133
"""
3234
A SIMD-enabled complex number type that supports vectorized operations.
3335

numojo/core/item.mojo

Lines changed: 105 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,11 @@ struct Item(
2929
Specifies the indices of an item of an array.
3030
"""
3131

32-
var _buf: UnsafePointer[Int]
32+
# Aliases
33+
alias _type: DType = DType.int
34+
35+
# Fields
36+
var _buf: UnsafePointer[Scalar[Self._type]]
3337
var ndim: Int
3438

3539
@always_inline("nodebug")
@@ -42,7 +46,7 @@ struct Item(
4246
Args:
4347
args: Initial values.
4448
"""
45-
self._buf = UnsafePointer[Int]().alloc(args.__len__())
49+
self._buf = UnsafePointer[Scalar[Self._type]]().alloc(args.__len__())
4650
self.ndim = args.__len__()
4751
for i in range(args.__len__()):
4852
self._buf[i] = index(args[i])
@@ -58,7 +62,7 @@ struct Item(
5862
args: Initial values.
5963
"""
6064
self.ndim = len(args)
61-
self._buf = UnsafePointer[Int]().alloc(self.ndim)
65+
self._buf = UnsafePointer[Scalar[Self._type]]().alloc(self.ndim)
6266
for i in range(self.ndim):
6367
(self._buf + i).init_pointee_copy(index(args[i]))
6468

@@ -70,7 +74,7 @@ struct Item(
7074
args: Initial values.
7175
"""
7276
self.ndim = len(args)
73-
self._buf = UnsafePointer[Int]().alloc(self.ndim)
77+
self._buf = UnsafePointer[Scalar[Self._type]]().alloc(self.ndim)
7478
for i in range(self.ndim):
7579
(self._buf + i).init_pointee_copy(Int(args[i]))
7680

@@ -110,61 +114,14 @@ struct Item(
110114

111115
if ndim == 0:
112116
self.ndim = 0
113-
self._buf = UnsafePointer[Int]()
117+
self._buf = UnsafePointer[Scalar[Self._type]]()
114118
else:
115119
self.ndim = ndim
116-
self._buf = UnsafePointer[Int]().alloc(ndim)
120+
self._buf = UnsafePointer[Scalar[Self._type]]().alloc(ndim)
117121
if initialized:
118122
for i in range(ndim):
119123
(self._buf + i).init_pointee_copy(0)
120124

121-
fn __init__(out self, idx: Int, shape: NDArrayShape) raises:
122-
"""
123-
Get indices of the i-th item of the array of the given shape.
124-
The item traverse the array in C-order.
125-
126-
Args:
127-
idx: The i-th item of the array.
128-
shape: The strides of the array.
129-
130-
Examples:
131-
132-
The following example demonstrates how to get the indices (coordinates)
133-
of the 123-th item of a 3D array with shape (20, 30, 40).
134-
135-
```console
136-
>>> from numojo.prelude import *
137-
>>> var item = Item(123, Shape(20, 30, 40))
138-
>>> print(item)
139-
Item at index: (0,3,3) Length: 3
140-
```
141-
"""
142-
143-
if (idx < 0) or (idx >= shape.size_of_array()):
144-
raise Error(
145-
IndexError(
146-
message=String(
147-
"Linear index {} out of range [0, {})."
148-
).format(idx, shape.size_of_array()),
149-
suggestion=String(
150-
"Ensure 0 <= idx < total size ({})."
151-
).format(shape.size_of_array()),
152-
location=String(
153-
"Item.__init__(idx: Int, shape: NDArrayShape)"
154-
),
155-
)
156-
)
157-
158-
self.ndim = shape.ndim
159-
self._buf = UnsafePointer[Int]().alloc(self.ndim)
160-
161-
var strides = NDArrayStrides(shape, order="C")
162-
var remainder = idx
163-
164-
for i in range(self.ndim):
165-
(self._buf + i).init_pointee_copy(remainder // strides._buf[i])
166-
remainder %= strides._buf[i]
167-
168125
@always_inline("nodebug")
169126
fn __copyinit__(out self, other: Self):
170127
"""Copy construct the tuple.
@@ -173,7 +130,7 @@ struct Item(
173130
other: The tuple to copy.
174131
"""
175132
self.ndim = other.ndim
176-
self._buf = UnsafePointer[Int]().alloc(self.ndim)
133+
self._buf = UnsafePointer[Scalar[Self._type]]().alloc(self.ndim)
177134
memcpy(self._buf, other._buf, self.ndim)
178135

179136
@always_inline("nodebug")
@@ -232,7 +189,7 @@ struct Item(
232189
)
233190
)
234191
var normalized_idx: Int = self.normalize_index(index_int(idx))
235-
return self._buf[normalized_idx]
192+
return Int(self._buf[normalized_idx])
236193

237194
@always_inline("nodebug")
238195
fn __getitem__(self, slice_index: Slice) raises -> Self:
@@ -458,33 +415,6 @@ struct Item(
458415

459416
return new_item^
460417

461-
fn offset(self, strides: NDArrayStrides) -> Int:
462-
"""
463-
Calculates the offset of the item according to strides.
464-
465-
Args:
466-
strides: The strides of the array.
467-
468-
Returns:
469-
The offset of the item.
470-
471-
Examples:
472-
473-
```mojo
474-
from numojo.prelude import *
475-
var item = Item(1, 2, 3)
476-
var strides = nm.Strides(4, 3, 2)
477-
print(item.offset(strides))
478-
# This prints `16`.
479-
```
480-
.
481-
"""
482-
483-
var offset: Int = 0
484-
for i in range(self.ndim):
485-
offset += self._buf[i] * strides._buf[i]
486-
return offset
487-
488418
# ===-------------------------------------------------------------------===#
489419
# Other private methods
490420
# ===-------------------------------------------------------------------===#
@@ -539,7 +469,7 @@ struct Item(
539469
if axis == self.ndim - 1:
540470
return result^
541471

542-
var value: Int = result._buf[axis]
472+
var value: Scalar[Self._type] = result._buf[axis]
543473
for i in range(axis, result.ndim - 1):
544474
result._buf[i] = result._buf[i + 1]
545475
result._buf[result.ndim - 1] = value
@@ -670,6 +600,98 @@ struct Item(
670600

671601
return (start, step, length)
672602

603+
fn load[width: Int = 1](self, idx: Int) raises -> SIMD[Self._type, width]:
604+
"""
605+
Load a SIMD vector from the Item at the specified index.
606+
607+
Parameters:
608+
width: The width of the SIMD vector.
609+
610+
Args:
611+
idx: The starting index to load from.
612+
613+
Returns:
614+
A SIMD vector containing the loaded values.
615+
616+
Raises:
617+
Error: If the load exceeds the bounds of the Item.
618+
"""
619+
if idx < 0 or idx + width > self.ndim:
620+
raise Error(
621+
IndexError(
622+
message=String(
623+
"Load operation out of bounds: idx={} width={} ndim={}"
624+
).format(idx, width, self.ndim),
625+
suggestion=(
626+
"Ensure that idx and width are within valid range."
627+
),
628+
location="Item.load",
629+
)
630+
)
631+
632+
return self._buf.load[width=width](idx)
633+
634+
fn store[
635+
width: Int = 1
636+
](self, idx: Int, value: SIMD[Self._type, width]) raises:
637+
"""
638+
Store a SIMD vector into the Item at the specified index.
639+
640+
Parameters:
641+
width: The width of the SIMD vector.
642+
643+
Args:
644+
idx: The starting index to store to.
645+
value: The SIMD vector to store.
646+
647+
Raises:
648+
Error: If the store exceeds the bounds of the Item.
649+
"""
650+
if idx < 0 or idx + width > self.ndim:
651+
raise Error(
652+
IndexError(
653+
message=String(
654+
"Store operation out of bounds: idx={} width={} ndim={}"
655+
).format(idx, width, self.ndim),
656+
suggestion=(
657+
"Ensure that idx and width are within valid range."
658+
),
659+
location="Item.store",
660+
)
661+
)
662+
663+
self._buf.store[width=width](idx, value)
664+
665+
fn unsafe_load[width: Int = 1](self, idx: Int) -> SIMD[Self._type, width]:
666+
"""
667+
Unsafely load a SIMD vector from the Item at the specified index.
668+
669+
Parameters:
670+
width: The width of the SIMD vector.
671+
672+
Args:
673+
idx: The starting index to load from.
674+
675+
Returns:
676+
A SIMD vector containing the loaded values.
677+
"""
678+
return self._buf.load[width=width](idx)
679+
680+
fn unsafe_store[
681+
width: Int = 1
682+
](self, idx: Int, value: SIMD[Self._type, width]):
683+
"""
684+
Unsafely store a SIMD vector into the Item at the specified index.
685+
686+
Parameters:
687+
width: The width of the SIMD vector.
688+
689+
Args:
690+
idx: The starting index to store to.
691+
value: The SIMD vector to store.
692+
"""
693+
self._buf.store[width=width](idx, value)
694+
673695

674696
struct _ItemIter[
675697
forward: Bool = True,

numojo/core/ndarray.mojo

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -409,7 +409,7 @@ struct NDArray[dtype: DType = DType.float64](
409409
"""
410410
var index_of_buffer: Int = 0
411411
for i in range(self.ndim):
412-
index_of_buffer += indices[i] * self.strides._buf[i]
412+
index_of_buffer += indices[i] * Int(self.strides._buf[i])
413413
return self._buf.ptr[index_of_buffer]
414414

415415
fn _getitem(self, indices: List[Int]) -> Scalar[dtype]:
@@ -437,7 +437,7 @@ struct NDArray[dtype: DType = DType.float64](
437437
"""
438438
var index_of_buffer: Int = 0
439439
for i in range(self.ndim):
440-
index_of_buffer += indices[i] * self.strides._buf[i]
440+
index_of_buffer += indices[i] * Int(self.strides._buf[i])
441441
return self._buf.ptr[index_of_buffer]
442442

443443
fn __getitem__(self) raises -> SIMD[dtype, 1]:
@@ -636,15 +636,15 @@ struct NDArray[dtype: DType = DType.float64](
636636
for lin in range(total):
637637
var rem = lin
638638
for d in range(out_ndim - 1, -1, -1):
639-
var dim = dst.shape._buf[d]
639+
var dim = Int(dst.shape._buf[d])
640640
coords[d] = rem % dim
641641
rem //= dim
642642
var off = base
643643
for d in range(out_ndim):
644644
off += coords[d] * src.strides._buf[d + 1]
645645
var dst_off = 0
646646
for d in range(out_ndim):
647-
dst_off += coords[d] * dst.strides._buf[d]
647+
dst_off += coords[d] * Int(dst.strides._buf[d])
648648
dst._buf.ptr[dst_off] = src._buf.ptr[off]
649649

650650
fn __getitem__(self, var *slices: Slice) raises -> Self:
@@ -1839,7 +1839,7 @@ struct NDArray[dtype: DType = DType.float64](
18391839
"""
18401840
var index_of_buffer: Int = 0
18411841
for i in range(self.ndim):
1842-
index_of_buffer += indices[i] * self.strides._buf[i]
1842+
index_of_buffer += indices[i] * Int(self.strides._buf[i])
18431843
self._buf.ptr[index_of_buffer] = val
18441844

18451845
fn __setitem__(self, idx: Int, val: Self) raises:
@@ -1964,14 +1964,14 @@ struct NDArray[dtype: DType = DType.float64](
19641964
for lin in range(total):
19651965
var rem = lin
19661966
for d in range(out_ndim - 1, -1, -1):
1967-
var dim = src.shape._buf[d]
1967+
var dim = Int(src.shape._buf[d])
19681968
coords[d] = rem % dim
19691969
rem //= dim
19701970
var dst_off = base
19711971
var src_off = 0
19721972
for d in range(out_ndim):
1973-
var stride_src = src.strides._buf[d]
1974-
var stride_dst = dst.strides._buf[d + 1]
1973+
var stride_src = Int(src.strides._buf[d])
1974+
var stride_dst = Int(dst.strides._buf[d + 1])
19751975
var c = coords[d]
19761976
dst_off += c * stride_dst
19771977
src_off += c * stride_src
@@ -3815,7 +3815,7 @@ struct NDArray[dtype: DType = DType.float64](
38153815
"""
38163816
Returns length of 0-th dimension.
38173817
"""
3818-
return self.shape._buf[0]
3818+
return Int(self.shape._buf[0])
38193819

38203820
fn __iter__(
38213821
self,
@@ -5851,7 +5851,7 @@ struct _NDIter[is_mutable: Bool, //, origin: Origin[is_mutable], dtype: DType](
58515851
(indices._buf + i).init_pointee_copy(
58525852
remainder // self.strides_compatible._buf[i]
58535853
)
5854-
remainder %= self.strides_compatible._buf[i]
5854+
remainder %= Int(self.strides_compatible._buf[i])
58555855
(indices._buf + self.axis).init_pointee_copy(remainder)
58565856

58575857
else:
@@ -5860,7 +5860,7 @@ struct _NDIter[is_mutable: Bool, //, origin: Origin[is_mutable], dtype: DType](
58605860
(indices._buf + i).init_pointee_copy(
58615861
remainder // self.strides_compatible._buf[i]
58625862
)
5863-
remainder %= self.strides_compatible._buf[i]
5863+
remainder %= Int(self.strides_compatible._buf[i])
58645864
(indices._buf + self.axis).init_pointee_copy(remainder)
58655865

58665866
return self.ptr[_get_offset(indices, self.strides)]
@@ -5893,15 +5893,15 @@ struct _NDIter[is_mutable: Bool, //, origin: Origin[is_mutable], dtype: DType](
58935893
(indices._buf + i).init_pointee_copy(
58945894
remainder // self.strides_compatible._buf[i]
58955895
)
5896-
remainder %= self.strides_compatible._buf[i]
5896+
remainder %= Int(self.strides_compatible._buf[i])
58975897
(indices._buf + self.axis).init_pointee_copy(remainder)
58985898
else:
58995899
for i in range(self.ndim - 1, -1, -1):
59005900
if i != self.axis:
59015901
(indices._buf + i).init_pointee_copy(
59025902
remainder // self.strides_compatible._buf[i]
59035903
)
5904-
remainder %= self.strides_compatible._buf[i]
5904+
remainder %= Int(self.strides_compatible._buf[i])
59055905
(indices._buf + self.axis).init_pointee_copy(remainder)
59065906

59075907
return self.ptr[_get_offset(indices, self.strides)]

0 commit comments

Comments
 (0)