From 1ea05ef2aa556d3f21dd1eda9a6e912487893970 Mon Sep 17 00:00:00 2001 From: shivasankar Date: Sun, 13 Apr 2025 01:39:30 +0900 Subject: [PATCH 1/4] removed complex dtype --- numojo/core/complex/complex_dtype.mojo | 1119 ------------------------ 1 file changed, 1119 deletions(-) delete mode 100644 numojo/core/complex/complex_dtype.mojo diff --git a/numojo/core/complex/complex_dtype.mojo b/numojo/core/complex/complex_dtype.mojo deleted file mode 100644 index 04abd8cc..00000000 --- a/numojo/core/complex/complex_dtype.mojo +++ /dev/null @@ -1,1119 +0,0 @@ -# Code for CDType is adapted from the Mojo Standard Library -# (https://github.com/modularml/mojo) -# licensed under the Apache License, Version 2.0. -# Modifications were made for the purposes of this project to -# support a Complex SIMD type. - -from collections import KeyElement -from hashlib._hasher import _HashableWithHasher, _Hasher -from sys import bitwidthof, os_is_windows, sizeof -from math import sqrt - -alias _mIsSigned = UInt8(1) -alias _mIsInteger = UInt8(1 << 7) -alias _mIsNotInteger = UInt8(~(1 << 7)) -alias _mIsFloat = UInt8(1 << 6) - - -@value -@register_passable("trivial") -struct CDType( - Stringable, - Writable, - Representable, - KeyElement, - CollectionElementNew, - _HashableWithHasher, -): - """Represents CDType and provides methods for working with it.""" - - alias type = __mlir_type.`!kgen.dtype` - var re_value: Self.type - var im_value: Self.type - - alias invalid = CDType( - __mlir_attr.`#kgen.dtype.constant : !kgen.dtype`, - __mlir_attr.`#kgen.dtype.constant : !kgen.dtype`, - ) - """Represents an invalid or unknown data type.""" - alias bool = CDType( - __mlir_attr.`#kgen.dtype.constant : !kgen.dtype`, - __mlir_attr.`#kgen.dtype.constant : !kgen.dtype`, - ) - """Represents a boolean data type.""" - alias int8 = CDType( - __mlir_attr.`#kgen.dtype.constant : !kgen.dtype`, - __mlir_attr.`#kgen.dtype.constant : !kgen.dtype`, - ) - """Represents a signed integer type whose bitwidth is 8.""" - alias uint8 = CDType( - __mlir_attr.`#kgen.dtype.constant : !kgen.dtype`, - __mlir_attr.`#kgen.dtype.constant : !kgen.dtype`, - ) - """Represents an unsigned integer type whose bitwidth is 8.""" - alias int16 = CDType( - __mlir_attr.`#kgen.dtype.constant : !kgen.dtype`, - __mlir_attr.`#kgen.dtype.constant : !kgen.dtype`, - ) - """Represents a signed integer type whose bitwidth is 16.""" - alias uint16 = CDType( - __mlir_attr.`#kgen.dtype.constant : !kgen.dtype`, - __mlir_attr.`#kgen.dtype.constant : !kgen.dtype`, - ) - """Represents an unsigned integer type whose bitwidth is 16.""" - alias int32 = CDType( - __mlir_attr.`#kgen.dtype.constant : !kgen.dtype`, - __mlir_attr.`#kgen.dtype.constant : !kgen.dtype`, - ) - """Represents a signed integer type whose bitwidth is 32.""" - alias uint32 = CDType( - __mlir_attr.`#kgen.dtype.constant : !kgen.dtype`, - __mlir_attr.`#kgen.dtype.constant : !kgen.dtype`, - ) - """Represents an unsigned integer type whose bitwidth is 32.""" - alias int64 = CDType( - __mlir_attr.`#kgen.dtype.constant : !kgen.dtype`, - __mlir_attr.`#kgen.dtype.constant : !kgen.dtype`, - ) - """Represents a signed integer type whose bitwidth is 64.""" - alias uint64 = CDType( - __mlir_attr.`#kgen.dtype.constant : !kgen.dtype`, - __mlir_attr.`#kgen.dtype.constant : !kgen.dtype`, - ) - """Represents an unsigned integer type whose bitwidth is 64.""" - alias float8_e5m2 = CDType( - __mlir_attr.`#kgen.dtype.constant : !kgen.dtype`, - __mlir_attr.`#kgen.dtype.constant : !kgen.dtype`, - ) - """Represents a FP8E5M2 floating point format from the [OFP8 - standard](https://www.opencompute.org/documents/ocp-8-bit-floating-point-specification-ofp8-revision-1-0-2023-12-01-pdf-1). - - The 8 bits are encoded as `seeeeemm`: - - (s)ign: 1 bit - - (e)xponent: 5 bits - - (m)antissa: 2 bits - - exponent bias: 15 - - nan: {0,1}11111{01,10,11} - - inf: 01111100 - - -inf: 11111100 - - -0: 10000000 - """ - alias float8_e5m2fnuz = CDType( - __mlir_attr.`#kgen.dtype.constant : !kgen.dtype`, - __mlir_attr.`#kgen.dtype.constant : !kgen.dtype`, - ) - """Represents a FP8E5M2FNUZ floating point format. - - The 8 bits are encoded as `seeeeemm`: - - (s)ign: 1 bit - - (e)xponent: 5 bits - - (m)antissa: 2 bits - - exponent bias: 16 - - nan: 10000000 - - fn: finite (no inf or -inf encodings) - - uz: unsigned zero (no -0 encoding) - """ - alias float8_e4m3 = CDType( - __mlir_attr.`#kgen.dtype.constant : !kgen.dtype`, - __mlir_attr.`#kgen.dtype.constant : !kgen.dtype`, - ) - """Represents a FP8E4M3 floating point format from the [OFP8 - standard](https://www.opencompute.org/documents/ocp-8-bit-floating-point-specification-ofp8-revision-1-0-2023-12-01-pdf-1). - - This type is named `float8_e4m3fn` (the "fn" stands for "finite") in some - frameworks, as it does not encode -inf or inf. - - The 8 bits are encoded as `seeeemmm`: - - (s)ign: 1 bit - - (e)xponent: 4 bits - - (m)antissa: 3 bits - - exponent bias: 7 - - nan: 01111111, 11111111 - - -0: 10000000 - - fn: finite (no inf or -inf encodings) - """ - alias float8_e4m3fnuz = CDType( - __mlir_attr.`#kgen.dtype.constant : !kgen.dtype`, - __mlir_attr.`#kgen.dtype.constant : !kgen.dtype`, - ) - """Represents a FP8E4M3FNUZ floating point format. - - The 8 bits are encoded as `seeeemmm`: - - (s)ign: 1 bit - - (e)xponent: 4 bits - - (m)antissa: 3 bits - - exponent bias: 8 - - nan: 10000000 - - fn: finite (no inf or -inf encodings) - - uz: unsigned zero (no -0 encoding) - """ - alias bfloat16 = CDType( - __mlir_attr.`#kgen.dtype.constant : !kgen.dtype`, - __mlir_attr.`#kgen.dtype.constant : !kgen.dtype`, - ) - """Represents a brain floating point value whose bitwidth is 16.""" - alias float16 = CDType( - __mlir_attr.`#kgen.dtype.constant : !kgen.dtype`, - __mlir_attr.`#kgen.dtype.constant : !kgen.dtype`, - ) - """Represents an IEEE754-2008 `binary16` floating point value.""" - alias float32 = CDType( - __mlir_attr.`#kgen.dtype.constant : !kgen.dtype`, - __mlir_attr.`#kgen.dtype.constant : !kgen.dtype`, - ) - """Represents an IEEE754-2008 `binary32` floating point value.""" - alias tensor_float32 = CDType( - __mlir_attr.`#kgen.dtype.constant : !kgen.dtype`, - __mlir_attr.`#kgen.dtype.constant : !kgen.dtype`, - ) - """Represents a special floating point format supported by NVIDIA Tensor - Cores, with the same range as float32 and reduced precision (>=10 bits). - Note that this type is only available on NVIDIA GPUs. - """ - alias float64 = CDType( - __mlir_attr.`#kgen.dtype.constant : !kgen.dtype`, - __mlir_attr.`#kgen.dtype.constant : !kgen.dtype`, - ) - """Represents an IEEE754-2008 `binary64` floating point value.""" - alias index = CDType( - __mlir_attr.`#kgen.dtype.constant : !kgen.dtype`, - __mlir_attr.`#kgen.dtype.constant : !kgen.dtype`, - ) - """Represents an integral type whose bitwidth is the maximum integral value - on the system.""" - - @always_inline - fn __init__(out self, *, other: Self): - """Copy this CDType. - - Arguments: - other: The CDType to copy. - """ - self = other - - @always_inline - # @implicit - fn __init__(out self, re_value: Self.type, im_value: Self.type): - """Construct a CDType from MLIR dtype. - - Arguments: - value: The MLIR dtype. - """ - self.re_value = re_value - self.im_value = im_value - - @staticmethod - @parameter - fn _from_dtype[dtype: DType]() -> CDType: - """Construct a CDType from a DType. - - Arguments: - dtype: The DType to convert. - """ - - @parameter - if dtype == DType.bool: - return CDType.bool - if dtype == DType.int8: - return CDType.int8 - if dtype == DType.uint8: - return CDType.uint8 - if dtype == DType.int16: - return CDType.int16 - if dtype == DType.uint16: - return CDType.uint16 - if dtype == DType.int32: - return CDType.int32 - if dtype == DType.uint32: - return CDType.uint32 - if dtype == DType.int64: - return CDType.int64 - if dtype == DType.uint64: - return CDType.uint64 - if dtype == DType.index: - return CDType.index - if dtype == DType.float8_e5m2: - return CDType.float8_e5m2 - if dtype == DType.float8_e5m2fnuz: - return CDType.float8_e5m2fnuz - if dtype == DType.float8_e4m3: - return CDType.float8_e4m3 - if dtype == DType.float8_e4m3fnuz: - return CDType.float8_e4m3fnuz - if dtype == DType.bfloat16: - return CDType.bfloat16 - if dtype == DType.float16: - return CDType.float16 - if dtype == DType.float32: - return CDType.float32 - if dtype == DType.tensor_float32: - return CDType.tensor_float32 - if dtype == DType.float64: - return CDType.float64 - if dtype == DType.invalid: - return CDType.invalid - else: - return CDType.invalid - - @staticmethod - fn _from_dtype(dtype: DType) -> CDType: - """Construct a CDType from a DType. - - Arguments: - dtype: The DType to convert. - """ - - if dtype == DType.bool: - return CDType.bool - if dtype == DType.int8: - return CDType.int8 - if dtype == DType.uint8: - return CDType.uint8 - if dtype == DType.int16: - return CDType.int16 - if dtype == DType.uint16: - return CDType.uint16 - if dtype == DType.int32: - return CDType.int32 - if dtype == DType.uint32: - return CDType.uint32 - if dtype == DType.int64: - return CDType.int64 - if dtype == DType.uint64: - return CDType.uint64 - if dtype == DType.index: - return CDType.index - if dtype == DType.float8_e5m2: - return CDType.float8_e5m2 - if dtype == DType.float8_e5m2fnuz: - return CDType.float8_e5m2fnuz - if dtype == DType.float8_e4m3: - return CDType.float8_e4m3 - if dtype == DType.float8_e4m3fnuz: - return CDType.float8_e4m3fnuz - if dtype == DType.bfloat16: - return CDType.bfloat16 - if dtype == DType.float16: - return CDType.float16 - if dtype == DType.float32: - return CDType.float32 - if dtype == DType.tensor_float32: - return CDType.tensor_float32 - if dtype == DType.float64: - return CDType.float64 - if dtype == DType.invalid: - return CDType.invalid - else: - return CDType.invalid - - @staticmethod - fn _from_str(str: String) -> CDType: - """Construct a CDType from a string. - - Arguments: - str: The name of the CDType. - """ - if str.startswith(String("CDType.")): - return Self._from_str(str.removeprefix("CDType.")) - elif str == String("bool"): - return CDType.bool - elif str == String("int8"): - return CDType.int8 - elif str == String("uint8"): - return CDType.uint8 - elif str == String("int16"): - return CDType.int16 - elif str == String("uint16"): - return CDType.uint16 - elif str == String("int32"): - return CDType.int32 - elif str == String("uint32"): - return CDType.uint32 - elif str == String("int64"): - return CDType.int64 - elif str == String("uint64"): - return CDType.uint64 - elif str == String("index"): - return CDType.index - elif str == String("float8_e5m2"): - return CDType.float8_e5m2 - elif str == String("float8_e5m2fnuz"): - return CDType.float8_e5m2fnuz - elif str == String("float8_e4m3"): - return CDType.float8_e4m3 - elif str == String("float8_e4m3fnuz"): - return CDType.float8_e4m3fnuz - elif str == String("bfloat16"): - return CDType.bfloat16 - elif str == String("float16"): - return CDType.float16 - elif str == String("float32"): - return CDType.float32 - elif str == String("float64"): - return CDType.float64 - elif str == String("tensor_float32"): - return CDType.tensor_float32 - elif str == String("invalid"): - return CDType.invalid - else: - return CDType.invalid - - @staticmethod - @parameter - fn to_dtype[other: Self]() -> DType: - """Find the equivalent DType. - - Returns: - True if the DTypes are the same and False otherwise. - """ - - @parameter - if other == CDType.bool: - return DType.bool - if other == CDType.int8: - return DType.int8 - if other == CDType.uint8: - return DType.uint8 - if other == CDType.int16: - return DType.int16 - if other == CDType.uint16: - return DType.uint16 - if other == CDType.int32: - return DType.int32 - if other == CDType.uint32: - return DType.uint32 - if other == CDType.int64: - return DType.int64 - if other == CDType.uint64: - return DType.uint64 - if other == CDType.index: - return DType.index - if other == CDType.float8_e5m2: - return DType.float8_e5m2 - if other == CDType.float8_e5m2fnuz: - return DType.float8_e5m2fnuz - if other == CDType.float8_e4m3: - return DType.float8_e4m3 - if other == CDType.float8_e4m3fnuz: - return DType.float8_e4m3fnuz - if other == CDType.bfloat16: - return DType.bfloat16 - if other == CDType.float16: - return DType.float16 - if other == CDType.float32: - return DType.float32 - if other == CDType.tensor_float32: - return DType.tensor_float32 - if other == CDType.float64: - return DType.float64 - if other == CDType.invalid: - return DType.invalid - else: - return DType.invalid - - @no_inline - fn __str__(self) -> String: - """Gets the name of the CDType. - - Returns: - The name of the dtype. - """ - - return String.write(self) - - @no_inline - fn write_to[W: Writer](self, mut writer: W): - """ - Formats this dtype to the provided Writer. - - Parameters: - W: A type conforming to the Writable trait. - - Arguments: - writer: The object to write to. - """ - - if self == CDType.bool: - return writer.write("cbool") - if self == CDType.int8: - return writer.write("cint8") - if self == CDType.uint8: - return writer.write("cuint8") - if self == CDType.int16: - return writer.write("cint16") - if self == CDType.uint16: - return writer.write("cuint16") - if self == CDType.int32: - return writer.write("cint32") - if self == CDType.uint32: - return writer.write("cuint32") - if self == CDType.int64: - return writer.write("cint64") - if self == CDType.uint64: - return writer.write("cuint64") - if self == CDType.index: - return writer.write("cindex") - if self == CDType.float8_e5m2: - return writer.write("cfloat8_e5m2") - if self == CDType.float8_e5m2fnuz: - return writer.write("cfloat8_e5m2fnuz") - if self == CDType.float8_e4m3: - return writer.write("cfloat8_e4m3") - if self == CDType.float8_e4m3fnuz: - return writer.write("cfloat8_e4m3fnuz") - if self == CDType.bfloat16: - return writer.write("cbfloat16") - if self == CDType.float16: - return writer.write("cfloat16") - if self == CDType.float32: - return writer.write("cfloat32") - if self == CDType.tensor_float32: - return writer.write("ctensor_float32") - if self == CDType.float64: - return writer.write("cfloat64") - if self == CDType.invalid: - return writer.write("cinvalid") - return writer.write("<>") - - @always_inline("nodebug") - fn __repr__(self) -> String: - """Gets the representation of the CDType e.g. `"CDType.float32"`. - - Returns: - The representation of the dtype. - """ - return String.write("CDType.", self) - - @always_inline("nodebug") - fn get_value(self) -> __mlir_type.`!kgen.dtype`: - """Gets the associated internal kgen.dtype value. - - Returns: - The kgen.dtype value. - """ - return self.re_value - - @staticmethod - fn _from_ui8(ui8: __mlir_type.ui8) -> CDType: - return CDType._from_dtype(__mlir_op.`pop.dtype.from_ui8`(ui8)) - - @staticmethod - fn _from_ui8(ui8: __mlir_type.`!pop.scalar`) -> CDType: - return CDType._from_ui8( - __mlir_op.`pop.cast_to_builtin`[_type = __mlir_type.ui8](ui8) - ) - - @always_inline("nodebug") - fn _as_i8( - self, - ) -> __mlir_type.`!pop.scalar`: - var val = __mlir_op.`pop.dtype.to_ui8`(self.re_value) - return __mlir_op.`pop.cast_from_builtin`[ - _type = __mlir_type.`!pop.scalar` - ](val) - - @always_inline("nodebug") - fn __is__(self, rhs: CDType) -> Bool: - """Compares one CDType to another for equality. - - Arguments: - rhs: The CDType to compare against. - - Returns: - True if the DTypes are the same and False otherwise. - """ - return self == rhs - - @always_inline("nodebug") - fn __isnot__(self, rhs: CDType) -> Bool: - """Compares one CDType to another for inequality. - - Arguments: - rhs: The CDType to compare against. - - Returns: - True if the CDTypes are the same and False otherwise. - """ - return self != rhs - - @always_inline("nodebug") - fn __eq__(self, rhs: CDType) -> Bool: - """Compares one CDType to another for equality. - - Arguments: - rhs: The CDType to compare against. - - Returns: - True if the DTypes are the same and False otherwise. - """ - return __mlir_op.`pop.cmp`[pred = __mlir_attr.`#pop`]( - self._as_i8(), rhs._as_i8() - ) - - @always_inline("nodebug") - fn __ne__(self, rhs: CDType) -> Bool: - """Compares one CDType to another for inequality. - - Arguments: - rhs: The CDType to compare against. - - Returns: - False if the DTypes are the same and True otherwise. - """ - return __mlir_op.`pop.cmp`[pred = __mlir_attr.`#pop`]( - self._as_i8(), rhs._as_i8() - ) - - fn __hash__(self) -> UInt: - """Return a 64-bit hash for this `CDType` value. - - Returns: - A 64-bit integer hash of this `CDType` value. - """ - return hash(UInt8(self._as_i8())) - - fn __hash__[H: _Hasher](self, mut hasher: H): - """Updates hasher with this `CDType` value. - - Parameters: - H: The hasher type. - - Arguments: - hasher: The hasher instance. - """ - hasher._update_with_simd(UInt8(self._as_i8())) - - @always_inline("nodebug") - fn is_unsigned(self) -> Bool: - """Returns True if the type parameter is unsigned and False otherwise. - - Returns: - Returns True if the input type parameter is unsigned. - """ - if not self.is_integral(): - return False - return Bool( - __mlir_op.`pop.cmp`[pred = __mlir_attr.`#pop`]( - __mlir_op.`pop.simd.and`(self._as_i8(), _mIsSigned.value), - UInt8(0).value, - ) - ) - - @always_inline("nodebug") - fn is_signed(self) -> Bool: - """Returns True if the type parameter is signed and False otherwise. - - Returns: - Returns True if the input type parameter is signed. - """ - if self is CDType.index or self.is_floating_point(): - return True - if not self.is_integral(): - return False - return Bool( - __mlir_op.`pop.cmp`[pred = __mlir_attr.`#pop`]( - __mlir_op.`pop.simd.and`(self._as_i8(), _mIsSigned.value), - UInt8(0).value, - ) - ) - - @always_inline("nodebug") - fn _is_non_index_integral(self) -> Bool: - """Returns True if the type parameter is a non-index integer value and False otherwise. - - Returns: - Returns True if the input type parameter is a non-index integer. - """ - return Bool( - __mlir_op.`pop.cmp`[pred = __mlir_attr.`#pop`]( - __mlir_op.`pop.simd.and`(self._as_i8(), _mIsInteger.value), - UInt8(0).value, - ) - ) - - @always_inline("nodebug") - fn is_integral(self) -> Bool: - """Returns True if the type parameter is an integer and False otherwise. - - Returns: - Returns True if the input type parameter is an integer. - """ - if self is CDType.index: - return True - return self._is_non_index_integral() - - @always_inline("nodebug") - fn is_floating_point(self) -> Bool: - """Returns True if the type parameter is a floating-point and False - otherwise. - - Returns: - Returns True if the input type parameter is a floating-point. - """ - if self.is_integral(): - return False - return Bool( - __mlir_op.`pop.cmp`[pred = __mlir_attr.`#pop`]( - __mlir_op.`pop.simd.and`(self._as_i8(), _mIsFloat.value), - UInt8(0).value, - ) - ) - - @always_inline("nodebug") - fn is_float8(self) -> Bool: - """Returns True if the type is a 8bit-precision floating point type, - e.g. float8_e5m2, float8_e5m2fnuz, float8_e4m3 and float8_e4m3fnuz. - - Returns: - True if the type is a 8bit-precision float, false otherwise. - """ - - return self in ( - CDType.float8_e5m2, - CDType.float8_e4m3, - CDType.float8_e5m2fnuz, - CDType.float8_e4m3fnuz, - ) - - @always_inline("nodebug") - fn is_half_float(self) -> Bool: - """Returns True if the type is a half-precision floating point type, - e.g. either fp16 or bf16. - - Returns: - True if the type is a half-precision float, false otherwise.. - """ - - return self in (CDType.bfloat16, CDType.float16) - - @always_inline("nodebug") - fn is_numeric(self) -> Bool: - """Returns True if the type parameter is numeric (i.e. you can perform - arithmetic operations on). - - Returns: - Returns True if the input type parameter is either integral or - floating-point. - """ - return self.is_integral() or self.is_floating_point() - - @always_inline - fn sizeof(self) -> Int: - """Returns the size in bytes of the current CDType. - - Returns: - Returns the size in bytes of the current CDType. - """ - - if self._is_non_index_integral(): - return Int( - UInt8( - __mlir_op.`pop.shl`( - UInt8(1).value, - __mlir_op.`pop.sub`( - __mlir_op.`pop.shr`( - __mlir_op.`pop.simd.and`( - self._as_i8(), _mIsNotInteger.value - ), - UInt8(1).value, - ), - UInt8(3).value, - ), - ) - ) - ) - - if self == CDType.bool: - return 2 * sizeof[DType.bool]() - if self == CDType.index: - return 2 * sizeof[DType.index]() - if self == CDType.float8_e5m2: - return 2 * sizeof[DType.float8_e5m2]() - if self == CDType.float8_e5m2fnuz: - return 2 * sizeof[DType.float8_e5m2fnuz]() - if self == CDType.float8_e4m3: - return 2 * sizeof[DType.float8_e4m3]() - if self == CDType.float8_e4m3fnuz: - return 2 * sizeof[DType.float8_e4m3fnuz]() - if self == CDType.bfloat16: - return 2 * sizeof[DType.bfloat16]() - if self == CDType.float16: - return 2 * sizeof[DType.float16]() - if self == CDType.float32: - return 2 * sizeof[DType.float32]() - if self == CDType.tensor_float32: - return 2 * sizeof[DType.tensor_float32]() - if self == CDType.float64: - return 2 * sizeof[DType.float64]() - return 2 * sizeof[DType.invalid]() - - @always_inline - fn bitwidth(self) -> Int: - """Returns the size in bits of the current CDType. - - Returns: - Returns the size in bits of the current CDType. - """ - return 2 * 8 * self.sizeof() - - # ===-------------------------------------------------------------------===# - # dispatch_integral - # ===-------------------------------------------------------------------===# - - @always_inline - fn dispatch_integral[ - func: fn[type: CDType] () capturing [_] -> None - ](self) raises: - """Dispatches an integral function corresponding to the current CDType. - - Constraints: - CDType must be integral. - - Parameters: - func: A parametrized on dtype function to dispatch. - """ - if self is CDType.uint8: - func[CDType.uint8]() - elif self is CDType.int8: - func[CDType.int8]() - elif self is CDType.uint16: - func[CDType.uint16]() - elif self is CDType.int16: - func[CDType.int16]() - elif self is CDType.uint32: - func[CDType.uint32]() - elif self is CDType.int32: - func[CDType.int32]() - elif self is CDType.uint64: - func[CDType.uint64]() - elif self is CDType.int64: - func[CDType.int64]() - elif self is CDType.index: - func[CDType.index]() - else: - raise Error("only integral types are supported") - - # ===-------------------------------------------------------------------===# - # dispatch_floating - # ===-------------------------------------------------------------------===# - - @always_inline - fn dispatch_floating[ - func: fn[type: CDType] () capturing [_] -> None - ](self) raises: - """Dispatches a floating-point function corresponding to the current CDType. - - Constraints: - CDType must be floating-point or integral. - - Parameters: - func: A parametrized on dtype function to dispatch. - """ - if self is CDType.float16: - func[CDType.float16]() - # TODO(#15473): Enable after extending LLVM support - # elif self is CDType.bfloat16: - # func[CDType.bfloat16]() - elif self is CDType.float32: - func[CDType.float32]() - elif self is CDType.float64: - func[CDType.float64]() - else: - raise Error("only floating point types are supported") - - @always_inline - fn _dispatch_bitwidth[ - func: fn[type: CDType] () capturing [_] -> None, - ](self) raises: - """Dispatches a function corresponding to the current CDType's bitwidth. - This should only be used if func only depends on the bitwidth of the dtype, - and not other properties of the dtype. - - Parameters: - func: A parametrized on dtype function to dispatch. - """ - var bitwidth = self.bitwidth() - if bitwidth == 8: - func[CDType.uint8]() - elif bitwidth == 16: - func[CDType.uint16]() - elif bitwidth == 32: - func[CDType.uint32]() - elif bitwidth == 64: - func[CDType.uint64]() - else: - raise Error( - "bitwidth_dispatch only supports types with bitwidth [8, 16," - " 32, 64]" - ) - return - - @always_inline - fn _dispatch_custom[ - func: fn[type: CDType] () capturing [_] -> None, *dtypes: CDType - ](self) raises: - """Dispatches a function corresponding to current CDType if it matches - any type in the dtypes parameter. - - Parameters: - func: A parametrized on dtype function to dispatch. - dtypes: A list of DTypes on which to do dispatch. - """ - alias dtype_var = VariadicList[CDType](dtypes) - - @parameter - for idx in range(len(dtype_var)): - alias dtype = dtype_var[idx] - if self == dtype: - return func[dtype]() - - raise Error( - "dispatch_custom: dynamic_type does not match any dtype parameters" - ) - - # ===-------------------------------------------------------------------===# - # dispatch_arithmetic - # ===-------------------------------------------------------------------===# - - @always_inline - fn dispatch_arithmetic[ - func: fn[type: CDType] () capturing [_] -> None - ](self) raises: - """Dispatches a function corresponding to the current CDType. - - Parameters: - func: A parametrized on dtype function to dispatch. - """ - if self.is_floating_point(): - self.dispatch_floating[func]() - elif self.is_integral(): - self.dispatch_integral[func]() - else: - raise Error("only arithmetic types are supported") - - -# ===-------------------------------------------------------------------===# -# integral_type_of -# ===-------------------------------------------------------------------===# - - -@always_inline("nodebug") -fn _integral_type_of[type: CDType]() -> CDType: - """Gets the integral type which has the same bitwidth as the input type.""" - - @parameter - if type.is_integral(): - return type - - @parameter - if type.is_float8(): - return CDType.int8 - - @parameter - if type.is_half_float(): - return CDType.int16 - - @parameter - if type is CDType.float32 or type is CDType.tensor_float32: - return CDType.int32 - - @parameter - if type is CDType.float64: - return CDType.int64 - - return type.invalid - - -@always_inline("nodebug") -fn _uint_type_of[type: CDType]() -> CDType: - """Gets the unsigned integral type which has the same bitwidth as the input - type.""" - - @parameter - if type.is_integral() and type.is_unsigned(): - return type - - @parameter - if type.is_float8() or type is CDType.int8: - return CDType.uint8 - - @parameter - if type.is_half_float() or type is CDType.int16: - return CDType.uint16 - - @parameter - if ( - type is CDType.float32 - or type is CDType.tensor_float32 - or type is CDType.int32 - ): - return CDType.uint32 - - @parameter - if type is CDType.float64 or type is CDType.int64: - return CDType.uint64 - - return type.invalid - - -# ===-------------------------------------------------------------------===# -# _unsigned_integral_type_of -# ===-------------------------------------------------------------------===# - - -@always_inline("nodebug") -fn _unsigned_integral_type_of[type: CDType]() -> CDType: - """Gets the unsigned integral type which has the same bitwidth as - the input type.""" - - @parameter - if type.is_integral(): - return _uint_type_of_width[bitwidthof[CDType.to_dtype[type]()]()]() - - @parameter - if type.is_float8(): - return CDType.uint8 - - @parameter - if type.is_half_float(): - return CDType.uint16 - - @parameter - if type is CDType.float32 or type is CDType.tensor_float32: - return CDType.uint32 - - @parameter - if type is CDType.float64: - return CDType.uint64 - - return type.invalid - - -# ===-------------------------------------------------------------------===# -# _scientific_notation_digits -# ===-------------------------------------------------------------------===# - - -fn _scientific_notation_digits[type: CDType]() -> StringLiteral: - """Get the number of digits as a StringLiteral for the scientific notation - representation of a float. - """ - constrained[type.is_floating_point(), "expected floating point type"]() - - @parameter - if type.is_float8(): - return "2" - elif type.is_half_float(): - return "4" - elif type is CDType.float32 or type is CDType.tensor_float32: - return "8" - else: - constrained[type is CDType.float64, "unknown floating point type"]() - return "16" - - -# ===-------------------------------------------------------------------===# -# _int_type_of_width -# ===-------------------------------------------------------------------===# - - -@parameter -@always_inline -fn _int_type_of_width[width: Int]() -> CDType: - constrained[ - width == 8 or width == 16 or width == 32 or width == 64, - "width must be either 8, 16, 32, or 64", - ]() - - @parameter - if width == 8: - return CDType.int8 - elif width == 16: - return CDType.int16 - elif width == 32: - return CDType.int32 - else: - return CDType.int64 - - -# ===-------------------------------------------------------------------===# -# _uint_type_of_width -# ===-------------------------------------------------------------------===# - - -@parameter -@always_inline -fn _uint_type_of_width[width: Int]() -> CDType: - constrained[ - width == 8 or width == 16 or width == 32 or width == 64, - "width must be either 8, 16, 32, or 64", - ]() - - @parameter - if width == 8: - return CDType.uint8 - elif width == 16: - return CDType.uint16 - elif width == 32: - return CDType.uint32 - else: - return CDType.uint64 - - -# ===-------------------------------------------------------------------===# -# printf format -# ===-------------------------------------------------------------------===# - - -@always_inline -fn _index_printf_format() -> StringLiteral: - @parameter - if bitwidthof[Int]() == 32: - return "%d" - elif os_is_windows(): - return "%lld" - else: - return "%ld" - - -@always_inline -fn _get_dtype_printf_format[type: CDType]() -> StringLiteral: - @parameter - if type is CDType.bool: - return _index_printf_format() - elif type is CDType.uint8: - return "%hhu" - elif type is CDType.int8: - return "%hhi" - elif type is CDType.uint16: - return "%hu" - elif type is CDType.int16: - return "%hi" - elif type is CDType.uint32: - return "%u" - elif type is CDType.int32: - return "%i" - elif type is CDType.int64: - - @parameter - if os_is_windows(): - return "%lld" - else: - return "%ld" - elif type is CDType.uint64: - - @parameter - if os_is_windows(): - return "%llu" - else: - return "%lu" - elif type is CDType.index: - return _index_printf_format() - - elif type.is_floating_point(): - return "%.17g" - - else: - constrained[False, "invalid dtype"]() - - return "" From 477e37b6fc6e0792b40c2db633907711e8c6b14b Mon Sep 17 00:00:00 2001 From: shivasankar Date: Tue, 15 Apr 2025 17:55:06 +0900 Subject: [PATCH 2/4] removed cdtype and replace it with dtype --- mojoproject.toml | 2 +- numojo/__init__.mojo | 12 - numojo/core/__init__.mojo | 12 - numojo/core/complex/__init__.mojo | 1 - numojo/core/complex/complex_ndarray.mojo | 443 ++++++++++------ numojo/core/complex/complex_simd.mojo | 76 +-- numojo/core/datatypes.mojo | 26 - numojo/prelude.mojo | 12 - numojo/routines/creation.mojo | 638 +++++++++++------------ numojo/routines/io/formatting.mojo | 8 +- tests/core/test_complexArray.mojo | 96 ++-- tests/core/test_complexSIMD.mojo | 20 +- 12 files changed, 685 insertions(+), 661 deletions(-) diff --git a/mojoproject.toml b/mojoproject.toml index 996322da..35c31697 100644 --- a/mojoproject.toml +++ b/mojoproject.toml @@ -51,7 +51,7 @@ doc_pages = "mojo doc numojo/ -o docs.json" release = "clear && magic run final && magic run doc_pages" [dependencies] -max = "=25.1.1" +max = "=25.2" python = ">=3.11" numpy = ">=1.19" scipy = ">=1.14" \ No newline at end of file diff --git a/numojo/__init__.mojo b/numojo/__init__.mojo index 87e82680..08a7a7c1 100644 --- a/numojo/__init__.mojo +++ b/numojo/__init__.mojo @@ -13,7 +13,6 @@ from numojo.core.ndarray import NDArray from numojo.core.ndshape import NDArrayShape, Shape from numojo.core.ndstrides import NDArrayStrides, Strides from numojo.core.item import Item, item -from numojo.core.complex.complex_dtype import CDType from numojo.core.complex.complex_simd import ComplexSIMD, ComplexScalar from numojo.core.complex.complex_ndarray import ComplexNDArray from numojo.core.matrix import Matrix @@ -30,17 +29,6 @@ from numojo.core.datatypes import ( f16, f32, f64, - ci8, - ci16, - ci32, - ci64, - cu8, - cu16, - cu32, - cu64, - cf16, - cf32, - cf64, ) # ===----------------------------------------------------------------------=== # diff --git a/numojo/core/__init__.mojo b/numojo/core/__init__.mojo index 8515f14b..855b8029 100644 --- a/numojo/core/__init__.mojo +++ b/numojo/core/__init__.mojo @@ -7,7 +7,6 @@ from .ndshape import NDArrayShape from .ndstrides import NDArrayStrides from .complex import ( - CDType, ComplexSIMD, ComplexScalar, ComplexNDArray, @@ -25,17 +24,6 @@ from .datatypes import ( f16, f32, f64, - ci8, - ci16, - ci32, - ci64, - cu8, - cu16, - cu32, - cu64, - cf16, - cf32, - cf64, ) # from .utility import diff --git a/numojo/core/complex/__init__.mojo b/numojo/core/complex/__init__.mojo index 53cf6df7..5df2a495 100644 --- a/numojo/core/complex/__init__.mojo +++ b/numojo/core/complex/__init__.mojo @@ -1,3 +1,2 @@ -from .complex_dtype import CDType from .complex_simd import ComplexSIMD, ComplexScalar from .complex_ndarray import ComplexNDArray diff --git a/numojo/core/complex/complex_ndarray.mojo b/numojo/core/complex/complex_ndarray.mojo index 7e25b717..494324cf 100644 --- a/numojo/core/complex/complex_ndarray.mojo +++ b/numojo/core/complex/complex_ndarray.mojo @@ -85,21 +85,21 @@ from numojo.routines.statistics.averages import mean # ===----------------------------------------------------------------------===# # ComplexNDArray # ===----------------------------------------------------------------------===# +# TODO: Add SIMD width as a parameter. @value struct ComplexNDArray[ - cdtype: CDType, *, dtype: DType = CDType.to_dtype[cdtype]() + dtype: DType = DType.float64 ](Stringable, Representable, CollectionElement, Sized, Writable): """ Represents a Complex N-Dimensional Array. Parameters: - cdtype: Complex data type. - dtype: Real data type. + dtype: Complex data type. """ """FIELDS""" - var _re: NDArray[dtype] - var _im: NDArray[dtype] + var _re: NDArray[Self.dtype] + var _im: NDArray[Self.dtype] # It's redundant, but better to have it as fields. var ndim: Int @@ -116,7 +116,7 @@ struct ComplexNDArray[ """LIFETIME METHODS""" @always_inline("nodebug") - fn __init__(mut self, owned re: NDArray[dtype], owned im: NDArray[dtype]): + fn __init__(out self, owned re: NDArray[Self.dtype], owned im: NDArray[Self.dtype]): self._re = re self._im = im self.ndim = re.ndim @@ -127,7 +127,7 @@ struct ComplexNDArray[ @always_inline("nodebug") fn __init__( - mut self, + out self, shape: NDArrayShape, order: String = "C", ) raises: @@ -143,11 +143,11 @@ struct ComplexNDArray[ Example: ```mojo from numojo.prelude import * - var A = nm.ComplexNDArray[cf32](Shape(2,3,4)) + var A = nm.ComplexNDArray[f32](Shape(2,3,4)) ``` """ - self._re = NDArray[dtype](shape, order) - self._im = NDArray[dtype](shape, order) + self._re = NDArray[Self.dtype](shape, order) + self._im = NDArray[Self.dtype](shape, order) self.ndim = self._re.ndim self.shape = self._re.shape self.size = self._re.size @@ -156,7 +156,7 @@ struct ComplexNDArray[ @always_inline("nodebug") fn __init__( - mut self, + out self, shape: List[Int], order: String = "C", ) raises: @@ -167,8 +167,8 @@ struct ComplexNDArray[ shape: List of shape. order: Memory order C or F. """ - self._re = NDArray[dtype](shape, order) - self._im = NDArray[dtype](shape, order) + self._re = NDArray[Self.dtype](shape, order) + self._im = NDArray[Self.dtype](shape, order) self.ndim = self._re.ndim self.shape = self._re.shape self.size = self._re.size @@ -177,7 +177,7 @@ struct ComplexNDArray[ @always_inline("nodebug") fn __init__( - mut self, + out self, shape: VariadicList[Int], order: String = "C", ) raises: @@ -188,8 +188,8 @@ struct ComplexNDArray[ shape: Variadic List of shape. order: Memory order C or F. """ - self._re = NDArray[dtype](shape, order) - self._im = NDArray[dtype](shape, order) + self._re = NDArray[Self.dtype](shape, order) + self._im = NDArray[Self.dtype](shape, order) self.ndim = self._re.ndim self.shape = self._re.shape self.size = self._re.size @@ -197,7 +197,7 @@ struct ComplexNDArray[ self.flags = self._re.flags fn __init__( - mut self, + out self, shape: List[Int], offset: Int, strides: List[Int], @@ -205,8 +205,8 @@ struct ComplexNDArray[ """ Extremely specific ComplexNDArray initializer. """ - self._re = NDArray[dtype](shape, offset, strides) - self._im = NDArray[dtype](shape, offset, strides) + self._re = NDArray[Self.dtype](shape, offset, strides) + self._im = NDArray[Self.dtype](shape, offset, strides) self.ndim = self._re.ndim self.shape = self._re.shape self.size = self._re.size @@ -239,14 +239,14 @@ struct ComplexNDArray[ self.ndim = ndim self.size = size self.flags = flags - self._re = NDArray[dtype](shape, strides, ndim, size, flags) - self._im = NDArray[dtype](shape, strides, ndim, size, flags) + self._re = NDArray[Self.dtype](shape, strides, ndim, size, flags) + self._im = NDArray[Self.dtype](shape, strides, ndim, size, flags) fn __init__( mut self, shape: NDArrayShape, - ref buffer_re: UnsafePointer[Scalar[dtype]], - ref buffer_im: UnsafePointer[Scalar[dtype]], + ref buffer_re: UnsafePointer[Scalar[Self.dtype]], + ref buffer_im: UnsafePointer[Scalar[Self.dtype]], offset: Int, strides: NDArrayStrides, ) raises: @@ -262,7 +262,7 @@ struct ComplexNDArray[ strides: Strides of the array. """ self._re = NDArray(shape, buffer_re, offset, strides) - self._im = NDArray(shape, buffer_re, offset, strides) + self._im = NDArray(shape, buffer_im, offset, strides) self.ndim = self._re.ndim self.shape = self._re.shape self.size = self._re.size @@ -344,7 +344,7 @@ struct ComplexNDArray[ # fn load[width: Int](self, *indices: Int) raises -> SIMD[dtype, width] # Load SIMD at coordinates # ===-------------------------------------------------------------------===# - fn _getitem(self, *indices: Int) -> ComplexSIMD[cdtype, dtype=dtype]: + fn _getitem(self, *indices: Int) -> ComplexSIMD[Self.dtype]: """ Get item at indices and bypass all boundary checks. ***UNSAFE!*** No boundary checks made, for internal use only. @@ -362,19 +362,19 @@ struct ComplexNDArray[ ```mojo import numojo as nm - var A = nm.ones[nm.cf32](nm.Shape(2,3,4)) + var A = nm.ones[nm.f32](nm.Shape(2,3,4)) print(A._getitem(1,2,3)) ``` """ var index_of_buffer: Int = 0 for i in range(self.ndim): index_of_buffer += indices[i] * self.strides._buf[i] - return ComplexSIMD[cdtype, dtype=dtype]( + return ComplexSIMD[Self.dtype]( re=self._re._buf.ptr.load[width=1](index_of_buffer), im=self._im._buf.ptr.load[width=1](index_of_buffer), ) - fn __getitem__(self) raises -> ComplexSIMD[cdtype, dtype=dtype, size=1]: + fn __getitem__(self) raises -> ComplexSIMD[Self.dtype]: """ Gets the value of the 0-D Complex array. @@ -388,7 +388,7 @@ struct ComplexNDArray[ ```console >>> import numojo as nm - >>> var A = nm.ones[nm.cf32](nm.Shape(2,3,4)) + >>> var A = nm.ones[nm.f32](nm.Shape(2,3,4)) >>> print(A[]) # gets values of the 0-D array. ```. """ @@ -397,14 +397,14 @@ struct ComplexNDArray[ "\nError in `numojo.ComplexNDArray.__getitem__()`: " "Cannot get value without index." ) - return ComplexSIMD[cdtype, dtype=dtype]( + return ComplexSIMD[Self.dtype]( re=self._re._buf.ptr[], im=self._im._buf.ptr[], ) fn __getitem__( self, index: Item - ) raises -> ComplexSIMD[cdtype, dtype=dtype, size=1]: + ) raises -> ComplexSIMD[Self.dtype]: """ Get the value at the index list. @@ -422,7 +422,7 @@ struct ComplexNDArray[ ```console >>>import numojo as nm - >>>var A = nm.full[nm.cf32](nm.Shape(2, 5), ComplexSIMD[nm.cf32](1.0, 1.0)) + >>>var A = nm.full[nm.f32](nm.Shape(2, 5), ComplexSIMD[nm.f32](1.0, 1.0)) >>>print(A[Item(1, 2)]) # gets values of the element at (1, 2). ```. """ @@ -446,7 +446,7 @@ struct ComplexNDArray[ ) var idx: Int = _get_offset(index, self.strides) - return ComplexSIMD[cdtype, dtype=dtype]( + return ComplexSIMD[Self.dtype]( re=self._re._buf.ptr.load[width=1](idx), im=self._im._buf.ptr.load[width=1](idx), ) @@ -468,7 +468,7 @@ struct ComplexNDArray[ ```console >>>import numojo as nm - >>>var a = nm.full[nm.cf32](nm.Shape(2, 5), ComplexSIMD[nm.cf32](1.0, 1.0)) + >>>var a = nm.full[nm.f32](nm.Shape(2, 5), ComplexSIMD[nm.f32](1.0, 1.0)) >>>print(a[1]) # returns the second row of the array. ```. """ @@ -484,8 +484,8 @@ struct ComplexNDArray[ var narr: Self if self.ndim == 1: - narr = creation._0darray[cdtype, dtype=dtype]( - ComplexSIMD[cdtype, dtype=dtype]( + narr = creation._0darray[Self.dtype]( + ComplexSIMD[Self.dtype]( re=self._re._buf.ptr[idx], im=self._im._buf.ptr[idx], ), @@ -513,7 +513,7 @@ struct ComplexNDArray[ ```console >>>import numojo as nm - >>>var a = nm.full[nm.cf32](nm.Shape(2, 5), ComplexSIMD[nm.cf32](1.0, 1.0)) + >>>var a = nm.full[nm.f32](nm.Shape(2, 5), ComplexSIMD[nm.f32](1.0, 1.0)) >>>var b = a[:, 2:4] >>>print(b) # `arr[:, 2:4]` returns the corresponding sliced array (2 x 2). ```. @@ -548,7 +548,7 @@ struct ComplexNDArray[ ```console >>>import numojo as nm - >>>var a = nm.full[nm.cf32](nm.Shape(2, 5), ComplexSIMD[nm.cf32](1.0, 1.0)) + >>>var a = nm.full[nm.f32](nm.Shape(2, 5), ComplexSIMD[nm.f32](1.0, 1.0)) >>>var b = a[List[Slice](Slice(0, 2, 1), Slice(2, 4, 1))] # `arr[:, 2:4]` returns the corresponding sliced array (2 x 2). >>>print(b) ```. @@ -615,7 +615,7 @@ struct ComplexNDArray[ temp_stride *= nshape[i] # Create and iteratively set values in the new array - var narr = ComplexNDArray[cdtype, dtype=dtype]( + var narr = ComplexNDArray[Self.dtype]( offset=noffset, shape=nshape, strides=nstrides ) var index_re = List[Int]() @@ -667,7 +667,7 @@ struct ComplexNDArray[ ```console >>>import numojo as nm - >>>var a = nm.full[nm.cf32](nm.Shape(2, 5), ComplexSIMD[nm.cf32](1.0, 1.0)) + >>>var a = nm.full[nm.f32](nm.Shape(2, 5), ComplexSIMD[nm.f32](1.0, 1.0)) >>>var b = a[1, 2:4] >>>print(b) ```. @@ -699,8 +699,8 @@ struct ComplexNDArray[ var narr: Self if count_int == self.ndim: - narr = creation._0darray[cdtype, dtype=dtype]( - ComplexSIMD[cdtype, dtype=dtype]( + narr = creation._0darray[Self.dtype]( + ComplexSIMD[Self.dtype]( re=self._re._buf.ptr[], im=self._im._buf.ptr[], ), @@ -729,8 +729,8 @@ struct ComplexNDArray[ # Get the shape of resulted array var shape = indices.shape.join(self.shape._pop(0)) - var result: ComplexNDArray[cdtype, dtype=dtype] = ComplexNDArray[ - cdtype, dtype=dtype + var result: ComplexNDArray[Self.dtype] = ComplexNDArray[ + Self.dtype ](shape) var size_per_item = self.size // self.shape[0] @@ -810,7 +810,7 @@ struct ComplexNDArray[ len_of_result += 1 # Change the first number of the ndshape - var result = ComplexNDArray[cdtype, dtype=dtype]( + var result = ComplexNDArray[Self.dtype]( shape=NDArrayShape(len_of_result) ) @@ -860,7 +860,7 @@ struct ComplexNDArray[ var shape = self.shape shape._buf[0] = len_of_result - var result = ComplexNDArray[cdtype, dtype=dtype](shape) + var result = ComplexNDArray[Self.dtype](shape) var size_per_item = self.size // self.shape[0] # Fill in the values @@ -901,7 +901,7 @@ struct ComplexNDArray[ return self[mask_array] - fn item(self, owned index: Int) raises -> ComplexSIMD[cdtype, dtype=dtype]: + fn item(self, owned index: Int) raises -> ComplexSIMD[Self.dtype]: """ Return the scalar at the coordinates. If one index is given, get the i-th item of the complex array (not buffer). @@ -924,7 +924,7 @@ struct ComplexNDArray[ ```console >>> import numojo as nm - >>> var A = nm.full[nm.cf32](Shape(2, 2, 2), ComplexSIMD[nm.cf32](1.0, 1.0)) + >>> var A = nm.full[nm.f32](Shape(2, 2, 2), ComplexSIMD[nm.f32](1.0, 1.0)) >>> print(A.item(10)) # returns the 10-th item of the complex array. ```. """ @@ -950,7 +950,7 @@ struct ComplexNDArray[ ) if self.flags.F_CONTIGUOUS: - return ComplexSIMD[cdtype, dtype=dtype]( + return ComplexSIMD[Self.dtype]( re=( self._re._buf.ptr + _transfer_offset(index, self.strides) )[], @@ -960,12 +960,12 @@ struct ComplexNDArray[ ) else: - return ComplexSIMD[cdtype, dtype=dtype]( + return ComplexSIMD[Self.dtype]( re=(self._re._buf.ptr + index)[], im=(self._im._buf.ptr + index)[], ) - fn item(self, *index: Int) raises -> ComplexSIMD[cdtype, dtype=dtype]: + fn item(self, *index: Int) raises -> ComplexSIMD[Self.dtype]: """ Return the scalar at the coordinates. If one index is given, get the i-th item of the complex array (not buffer). @@ -988,7 +988,7 @@ struct ComplexNDArray[ ```console >>> import numojo as nm - >>> var A = nm.full[nm.cf32](Shape(2, 2, 2), ComplexSIMD[nm.cf32](1.0, 1.0)) + >>> var A = nm.full[nm.f32](Shape(2, 2, 2), ComplexSIMD[nm.f32](1.0, 1.0)) >>> print(A.item(1, 1, 1)) # returns the 10-th item of the complex array. ```. """ @@ -1002,7 +1002,7 @@ struct ComplexNDArray[ ) if self.ndim == 0: - return ComplexSIMD[cdtype, dtype=dtype]( + return ComplexSIMD[Self.dtype]( re=self._re._buf.ptr[], im=self._im._buf.ptr[], ) @@ -1019,12 +1019,12 @@ struct ComplexNDArray[ i, self.shape[i] ) ) - return ComplexSIMD[cdtype, dtype=dtype]( + return ComplexSIMD[Self.dtype]( re=(self._re._buf.ptr + _get_offset(index, self.strides))[], im=(self._im._buf.ptr + _get_offset(index, self.strides))[], ) - fn load(self, owned index: Int) raises -> ComplexSIMD[cdtype, dtype=dtype]: + fn load(self, owned index: Int) raises -> ComplexSIMD[Self.dtype]: """ Safely retrieve i-th item from the underlying buffer. @@ -1043,7 +1043,7 @@ struct ComplexNDArray[ ```console >>> import numojo as nm - >>> var A = nm.full[nm.cf32](Shape(2, 2, 2), ComplexSIMD[nm.cf32](1.0, 1.0)) + >>> var A = nm.full[nm.f32](Shape(2, 2, 2), ComplexSIMD[nm.f32](1.0, 1.0)) >>> print(A.load(10)) # returns the 10-th item of the complex array. ```. """ @@ -1059,14 +1059,14 @@ struct ComplexNDArray[ ).format(self.size) ) - return ComplexSIMD[cdtype, dtype=dtype]( + return ComplexSIMD[Self.dtype]( re=self._re._buf.ptr[index], im=self._im._buf.ptr[index], ) fn load[ width: Int = 1 - ](self, index: Int) raises -> ComplexSIMD[cdtype, dtype=dtype]: + ](self, index: Int) raises -> ComplexSIMD[Self.dtype]: """ Safely loads a ComplexSIMD element of size `width` at `index` from the underlying buffer. @@ -1092,7 +1092,7 @@ struct ComplexNDArray[ ).format(self.size) ) - return ComplexSIMD[cdtype, dtype=dtype]( + return ComplexSIMD[Self.dtype]( re=self._re._buf.ptr.load[width=1](index), im=self._im._buf.ptr.load[width=1](index), ) @@ -1100,7 +1100,7 @@ struct ComplexNDArray[ fn load[ width: Int = 1 ](self, *indices: Int) raises -> ComplexSIMD[ - cdtype, dtype=dtype, size=width + Self.dtype, width=width ]: """ Safely loads a ComplexSIMD element of size `width` at given variadic indices @@ -1122,7 +1122,7 @@ struct ComplexNDArray[ ```console >>> import numojo as nm - >>> var A = nm.full[nm.cf32](Shape(2, 2, 2), ComplexSIMD[nm.cf32](1.0, 1.0)) + >>> var A = nm.full[nm.f32](Shape(2, 2, 2), ComplexSIMD[nm.f32](1.0, 1.0)) >>> print(A.load(0, 1, 1)) ```. """ @@ -1147,7 +1147,7 @@ struct ComplexNDArray[ ) var idx: Int = _get_offset(indices, self.strides) - return ComplexSIMD[cdtype, dtype=dtype, size=width]( + return ComplexSIMD[Self.dtype, width=width]( re=self._re._buf.ptr.load[width=width](idx), im=self._im._buf.ptr.load[width=width](idx), ) @@ -1196,7 +1196,7 @@ struct ComplexNDArray[ return slices^ - fn _setitem(self, *indices: Int, val: ComplexSIMD[cdtype, dtype=dtype]): + fn _setitem(self, *indices: Int, val: ComplexSIMD[Self.dtype]): """ (UNSAFE! for internal use only.) Get item at indices and bypass all boundary checks. @@ -1326,7 +1326,7 @@ struct ComplexNDArray[ ) fn __setitem__( - mut self, index: Item, val: ComplexSIMD[cdtype, dtype=dtype] + mut self, index: Item, val: ComplexSIMD[Self.dtype] ) raises: """ Set the value at the index list. @@ -1354,8 +1354,8 @@ struct ComplexNDArray[ fn __setitem__( mut self, - mask: ComplexNDArray[cdtype, dtype=dtype], - value: ComplexSIMD[cdtype, dtype=dtype], + mask: ComplexNDArray[Self.dtype], + value: ComplexSIMD[Self.dtype], ) raises: """ Set the value of the array at the indices where the mask is true. @@ -1533,8 +1533,8 @@ struct ComplexNDArray[ fn __setitem__( mut self, - mask: ComplexNDArray[cdtype, dtype=dtype], - val: ComplexNDArray[cdtype, dtype=dtype], + mask: ComplexNDArray[Self.dtype], + val: ComplexNDArray[Self.dtype], ) raises: """ Set the value of the ComplexNDArray at the indices where the mask is true. @@ -1575,7 +1575,7 @@ struct ComplexNDArray[ "complex_ndarray:ComplexNDArray:__neg__: neg does not accept" " bool type arrays" ) - return self * ComplexSIMD[cdtype, dtype=dtype](-1.0, -1.0) + return self * ComplexSIMD[Self.dtype](-1.0, -1.0) @always_inline("nodebug") fn __eq__(self, other: Self) raises -> NDArray[DType.bool]: @@ -1588,7 +1588,7 @@ struct ComplexNDArray[ @always_inline("nodebug") fn __eq__( - self, other: ComplexSIMD[cdtype, dtype=dtype] + self, other: ComplexSIMD[Self.dtype] ) raises -> NDArray[DType.bool]: """ Itemwise equivalence between scalar and ComplexNDArray. @@ -1608,7 +1608,7 @@ struct ComplexNDArray[ @always_inline("nodebug") fn __ne__( - self, other: ComplexSIMD[cdtype, dtype=dtype] + self, other: ComplexSIMD[Self.dtype] ) raises -> NDArray[DType.bool]: """ Itemwise non-equivalence between scalar and ComplexNDArray. @@ -1619,7 +1619,7 @@ struct ComplexNDArray[ """ ARITHMETIC OPERATIONS """ - fn __add__(self, other: ComplexSIMD[cdtype, dtype=dtype]) raises -> Self: + fn __add__(self, other: ComplexSIMD[Self.dtype]) raises -> Self: """ Enables `ComplexNDArray + ComplexSIMD`. """ @@ -1653,7 +1653,7 @@ struct ComplexNDArray[ return Self(real, imag) fn __radd__( - mut self, other: ComplexSIMD[cdtype, dtype=dtype] + mut self, other: ComplexSIMD[Self.dtype] ) raises -> Self: """ Enables `ComplexSIMD + ComplexNDArray`. @@ -1682,7 +1682,7 @@ struct ComplexNDArray[ var imag: NDArray[dtype] = math.add[dtype](self._im, other) return Self(real, imag) - fn __iadd__(mut self, other: ComplexSIMD[cdtype, dtype=dtype]) raises: + fn __iadd__(mut self, other: ComplexSIMD[Self.dtype]) raises: """ Enables `ComplexNDArray += ComplexSIMD`. """ @@ -1710,7 +1710,7 @@ struct ComplexNDArray[ self._re += other self._im += other - fn __sub__(self, other: ComplexSIMD[cdtype, dtype=dtype]) raises -> Self: + fn __sub__(self, other: ComplexSIMD[Self.dtype]) raises -> Self: """ Enables `ComplexNDArray - ComplexSIMD`. """ @@ -1747,7 +1747,7 @@ struct ComplexNDArray[ return Self(real, imag) fn __rsub__( - mut self, other: ComplexSIMD[cdtype, dtype=dtype] + mut self, other: ComplexSIMD[Self.dtype] ) raises -> Self: """ Enables `ComplexSIMD - ComplexNDArray`. @@ -1772,7 +1772,7 @@ struct ComplexNDArray[ var imag: NDArray[dtype] = math.sub[dtype](other, self._im) return Self(real, imag) - fn __isub__(mut self, other: ComplexSIMD[cdtype, dtype=dtype]) raises: + fn __isub__(mut self, other: ComplexSIMD[Self.dtype]) raises: """ Enables `ComplexNDArray -= ComplexSIMD`. """ @@ -1807,7 +1807,7 @@ struct ComplexNDArray[ var im_re: NDArray[dtype] = linalg.matmul[dtype](self._im, other._re) return Self(re_re - im_im, re_im + im_re) - fn __mul__(self, other: ComplexSIMD[cdtype, dtype=dtype]) raises -> Self: + fn __mul__(self, other: ComplexSIMD[Self.dtype]) raises -> Self: """ Enables `ComplexNDArray * ComplexSIMD`. """ @@ -1843,7 +1843,7 @@ struct ComplexNDArray[ var imag: NDArray[dtype] = math.mul[dtype](self._im, other) return Self(real, imag) - fn __rmul__(self, other: ComplexSIMD[cdtype, dtype=dtype]) raises -> Self: + fn __rmul__(self, other: ComplexSIMD[Self.dtype]) raises -> Self: """ Enables `ComplexSIMD * ComplexNDArray`. """ @@ -1867,7 +1867,7 @@ struct ComplexNDArray[ var imag: NDArray[dtype] = math.mul[dtype](self._im, other) return Self(real, imag) - fn __imul__(mut self, other: ComplexSIMD[cdtype, dtype=dtype]) raises: + fn __imul__(mut self, other: ComplexSIMD[Self.dtype]) raises: """ Enables `ComplexNDArray *= ComplexSIMD`. """ @@ -1896,7 +1896,7 @@ struct ComplexNDArray[ self._im *= other fn __truediv__( - self, other: ComplexSIMD[cdtype, dtype=dtype] + self, other: ComplexSIMD[Self.dtype] ) raises -> Self: """ Enables `ComplexNDArray / ComplexSIMD`. @@ -1914,7 +1914,7 @@ struct ComplexNDArray[ return Self(real, imag) fn __truediv__( - self, other: ComplexNDArray[cdtype, dtype=dtype] + self, other: ComplexNDArray[Self.dtype] ) raises -> Self: """ Enables `ComplexNDArray / ComplexNDArray`. @@ -1934,7 +1934,7 @@ struct ComplexNDArray[ return Self(real, imag) fn __rtruediv__( - mut self, other: ComplexSIMD[cdtype, dtype=dtype] + mut self, other: ComplexSIMD[Self.dtype] ) raises -> Self: """ Enables `ComplexSIMD / ComplexNDArray`. @@ -1965,7 +1965,7 @@ struct ComplexNDArray[ var imag = numer._im / denom._re return Self(real, imag) - fn __itruediv__(mut self, other: ComplexSIMD[cdtype, dtype=dtype]) raises: + fn __itruediv__(mut self, other: ComplexSIMD[Self.dtype]) raises: """ Enables `ComplexNDArray /= ComplexSIMD`. """ @@ -2036,17 +2036,17 @@ struct ComplexNDArray[ An example is: ``` fn main() raises: - var A = ComplexNDArray[cf32](List[ComplexSIMD[cf32]](14,97,-59,-4,112,), shape=List[Int](5,)) + var A = ComplexNDArray[f32](List[ComplexSIMD[f32]](14,97,-59,-4,112,), shape=List[Int](5,)) print(repr(A)) ``` It prints what can be used to construct the array itself: ```console - ComplexNDArray[cf32](List[ComplexSIMD[cf32]](14,97,-59,-4,112,), shape=List[Int](5,)) + ComplexNDArray[f32](List[ComplexSIMD[f32]](14,97,-59,-4,112,), shape=List[Int](5,)) ```. """ try: var result: String = String("ComplexNDArray[CDType.") + String( - self.cdtype + self.dtype ) + String("](List[ComplexSIMD[CDType.c") + String( self._re.dtype ) + String( @@ -2186,7 +2186,7 @@ struct ComplexNDArray[ fn store[ width: Int = 1 - ](mut self, index: Int, val: ComplexSIMD[cdtype, dtype=dtype]) raises: + ](mut self, index: Int, val: ComplexSIMD[Self.dtype]) raises: """ Safely stores SIMD element of size `width` at `index` of the underlying buffer. @@ -2209,7 +2209,7 @@ struct ComplexNDArray[ fn store[ width: Int = 1 - ](mut self, *indices: Int, val: ComplexSIMD[cdtype, dtype=dtype]) raises: + ](mut self, *indices: Int, val: ComplexSIMD[Self.dtype]) raises: """ Safely stores SIMD element of size `width` at given variadic indices of the underlying buffer. @@ -2242,40 +2242,42 @@ struct ComplexNDArray[ self._re._buf.ptr.store(idx, val.re) self._im._buf.ptr.store(idx, val.im) - # fn __iter__(self) raises -> _ComplexNDArrayIter[__origin_of(self._re), __origin_of(self._im), cdtype, dtype]: - # """Iterate over elements of the NDArray, returning copied value. + fn __iter__( + self, + ) raises -> _ComplexNDArrayIter[__origin_of(self._re), Self.dtype]: + """ + Iterates over elements of the ComplexNDArray and return sub-arrays as view. - # Returns: - # An iterator of NDArray elements. + Returns: + An iterator of ComplexNDArray elements. + """ - # Notes: - # Need to add lifetimes after the new release. - # """ + return _ComplexNDArrayIter[__origin_of(self._re), Self.dtype]( + self, + dimension=0, + ) - # return _ComplexNDArrayIter[__origin_of(self._re), __origin_of(self._im), cdtype, dtype]( - # array=self, - # length=self.shape[0], - # ) + fn __reversed__( + self, + ) raises -> _ComplexNDArrayIter[__origin_of(self._re), Self.dtype, forward=False]: + """ + Iterates backwards over elements of the ComplexNDArray, returning + copied value. - # fn __reversed__( - # self, - # ) raises -> _ComplexNDArrayIter[__origin_of(self._re), __origin_of(self._im), cdtype, dtype, forward=False]: - # """Iterate backwards over elements of the NDArray, returning - # copied value. + Returns: + A reversed iterator of NDArray elements. + """ - # Returns: - # A reversed iterator of NDArray elements. - # """ + return _ComplexNDArrayIter[__origin_of(self._re), Self.dtype, forward=False]( + self, + dimension=0, + ) - # # return _ComplexNDArrayIter[__origin_of(self._re), __origin_of(self._im), cdtype, dtype, forward=False]( - # # array=self, - # # length=self.shape[0], - # # ) fn itemset( mut self, index: Variant[Int, List[Int]], - item: ComplexSIMD[cdtype, dtype=dtype], + item: ComplexSIMD[Self.dtype], ) raises: """Set the scalar at the coordinates. @@ -2350,63 +2352,158 @@ struct ComplexNDArray[ raise Error("Invalid type: " + type + ", must be 're' or 'im'") -# @value -# struct _ComplexNDArrayIter[ -# is_mutable: Bool, //, -# origin: Origin[is_mutable], -# cdtype: CDType, -# dtype: DType, -# forward: Bool = True, -# ]: -# """ -# Iterator for NDArray. - -# Parameters: -# is_mutable: Whether the iterator is mutable. -# origin: The lifetime of the underlying NDArray data. -# cdtype: The complex data type of the item. -# dtype: The data type of the item. -# forward: The iteration direction. `False` is backwards. -# """ - -# var index: Int -# var array: ComplexNDArray[cdtype, dtype=dtype] -# var length: Int - -# fn __init__( -# mut self, -# array: ComplexNDArray[cdtype, dtype=dtype], -# length: Int, -# ): -# self.index = 0 if forward else length -# self.length = length -# self.array = array - -# fn __iter__(self) -> Self: -# return self - -# fn __next__(mut self) raises -> ComplexNDArray[cdtype, dtype=dtype]: -# @parameter -# if forward: -# var current_index = self.index -# self.index += 1 -# return self.array.__getitem__(current_index) -# else: -# var current_index = self.index -# self.index -= 1 -# return self.array.__getitem__(current_index) - -# @always_inline -# fn __has_next__(self) -> Bool: -# @parameter -# if forward: -# return self.index < self.length -# else: -# return self.index > 0 - -# fn __len__(self) -> Int: -# @parameter -# if forward: -# return self.length - self.index -# else: -# return self.index +@value +struct _ComplexNDArrayIter[ + is_mutable: Bool, //, + origin: Origin[is_mutable], + dtype: DType, + forward: Bool = True, +]: + # TODO: + # Return a view instead of copy where possible + # (when Bufferable is supported). + """ + An iterator yielding `ndim-1` array slices over the given dimension. + It is the default iterator of the `ComplexNDArray.__iter__() method and for loops. + It can also be constructed using the `ComplexNDArray.iter_over_dimension()` method. + It trys to create a view where possible. + + Parameters: + is_mutable: Whether the iterator is mutable. + origin: The lifetime of the underlying NDArray data. + dtype: The data type of the item. + forward: The iteration direction. `False` is backwards. + """ + + var index: Int + var re_ptr: UnsafePointer[Scalar[dtype]] + var im_ptr: UnsafePointer[Scalar[dtype]] + var dimension: Int + var length: Int + var shape: NDArrayShape + var strides: NDArrayStrides + """Strides of array or view. It is not necessarily compatible with shape.""" + var ndim: Int + var size_of_item: Int + + fn __init__(out self, read a: ComplexNDArray[dtype], read dimension: Int) raises: + """ + Initialize the iterator. + + Args: + a: The array + dimension: Dimension to iterate over. + """ + + if dimension < 0 or dimension >= a.ndim: + raise Error("Axis must be in the range of [0, ndim).") + + self.re_ptr = a._re._buf.ptr + self.im_ptr = a._im._buf.ptr + self.dimension = dimension + self.shape = a.shape + self.strides = a.strides + self.ndim = a.ndim + self.length = a.shape[dimension] + self.size_of_item = a.size // a.shape[dimension] + # Status of the iterator + self.index = 0 if forward else a.shape[dimension] - 1 + + fn __iter__(self) -> Self: + return self + + fn __next__(mut self) raises -> ComplexNDArray[dtype]: + var res = ComplexNDArray[dtype](self.shape._pop(self.dimension)) + var current_index = self.index + + @parameter + if forward: + self.index += 1 + else: + self.index -= 1 + + for offset in range(self.size_of_item): + var remainder = offset + var item = Item(ndim=self.ndim, initialized=False) + + for i in range(self.ndim - 1, -1, -1): + if i != self.dimension: + (item._buf + i).init_pointee_copy(remainder % self.shape[i]) + remainder = remainder // self.shape[i] + else: + (item._buf + self.dimension).init_pointee_copy( + current_index + ) + + (res._re._buf.ptr + offset).init_pointee_copy( + self.re_ptr[_get_offset(item, self.strides)] + ) + (res._im._buf.ptr + offset).init_pointee_copy( + self.im_ptr[_get_offset(item, self.strides)] + ) + return res + + @always_inline + fn __has_next__(self) -> Bool: + @parameter + if forward: + return self.index < self.length + else: + return self.index >= 0 + + fn __len__(self) -> Int: + @parameter + if forward: + return self.length - self.index + else: + return self.index + + fn ith(self, index: Int) raises -> ComplexNDArray[dtype]: + """ + Gets the i-th array of the iterator. + + Args: + index: The index of the item. It must be non-negative. + + Returns: + The i-th `ndim-1`-D array of the iterator. + """ + + if (index >= self.length) or (index < 0): + raise Error( + String( + "\nError in `ComplexNDArrayIter.ith()`: " + "Index ({}) must be in the range of [0, {})" + ).format(index, self.length) + ) + + if self.ndim > 1: + var res = ComplexNDArray[dtype](self.shape._pop(self.dimension)) + + for offset in range(self.size_of_item): + var remainder = offset + var item = Item(ndim=self.ndim, initialized=False) + + for i in range(self.ndim - 1, -1, -1): + if i != self.dimension: + (item._buf + i).init_pointee_copy( + remainder % self.shape[i] + ) + remainder = remainder // self.shape[i] + else: + (item._buf + self.dimension).init_pointee_copy(index) + + (res._re._buf.ptr + offset).init_pointee_copy( + self.re_ptr[_get_offset(item, self.strides)] + ) + (res._im._buf.ptr + offset).init_pointee_copy( + self.im_ptr[_get_offset(item, self.strides)] + ) + return res + + else: # 0-D array + var res = numojo.creation._0darray[dtype](ComplexSIMD[dtype]( + self.re_ptr[index], + self.im_ptr[index] + )) + return res \ No newline at end of file diff --git a/numojo/core/complex/complex_simd.mojo b/numojo/core/complex/complex_simd.mojo index ac25c7a0..ca5040bb 100644 --- a/numojo/core/complex/complex_simd.mojo +++ b/numojo/core/complex/complex_simd.mojo @@ -1,22 +1,18 @@ from math import sqrt -from .complex_dtype import CDType - -alias ComplexScalar = ComplexSIMD[_, size=1] - +alias ComplexScalar = ComplexSIMD[_, width=1] @register_passable("trivial") struct ComplexSIMD[ - cdtype: CDType, *, dtype: DType = CDType.to_dtype[cdtype](), size: Int = 1 + dtype: DType, width: Int = 1 ](): """ - Represents a SIMD[dtype, 1] Complex number with real and imaginary parts. + Represents a Complex number SIMD type with real and imaginary parts. """ - # FIELDS """The underlying data real and imaginary parts of the complex number.""" - var re: SIMD[dtype, size] - var im: SIMD[dtype, size] + var re: SIMD[dtype, width] + var im: SIMD[dtype, width] @always_inline fn __init__(out self, other: Self): @@ -29,7 +25,7 @@ struct ComplexSIMD[ self = other @always_inline - fn __init__(out self, re: SIMD[dtype, size], im: SIMD[dtype, size]): + fn __init__(out self, re: SIMD[Self.dtype, Self.width], im: SIMD[Self.dtype, Self.width]): """ Initializes a ComplexSIMD instance with specified real and imaginary parts. @@ -39,8 +35,8 @@ struct ComplexSIMD[ Example: ```mojo - var A = ComplexSIMD[cf32](SIMD[f32, 1](1.0), SIMD[f32, 1](2.0)) - var B = ComplexSIMD[cf32](SIMD[f32, 1](3.0), SIMD[f32, 1](4.0)) + var A = ComplexSIMD[f32](SIMD[f32, 1](1.0), SIMD[f32, 1](2.0)) + var B = ComplexSIMD[f32](SIMD[f32, 1](3.0), SIMD[f32, 1](4.0)) var C = A + B print(C) ``` @@ -50,7 +46,7 @@ struct ComplexSIMD[ self.im = im @always_inline - fn __init__(out self, val: SIMD[dtype, size]): + fn __init__(out self, val: SIMD[Self.dtype, Self.width]): """ Initializes a ComplexSIMD instance with specified real and imaginary parts. @@ -58,8 +54,8 @@ struct ComplexSIMD[ re: The real part of the complex number. im: The imaginary part of the complex number. """ - self.re = rebind[Scalar[dtype]](val) - self.im = rebind[Scalar[dtype]](val) + self.re = val + self.im = val fn __add__(self, other: Self) -> Self: """ @@ -171,7 +167,7 @@ struct ComplexSIMD[ """ return Self(self.re**other.re, self.im**other.im) - fn __pow__(self, other: Scalar[dtype]) -> Self: + fn __pow__(self, other: Scalar[Self.dtype]) -> Self: """ Raises this ComplexSIMD instance to the power of a scalar. @@ -260,9 +256,9 @@ struct ComplexSIMD[ Returns: String: The string representation of the ComplexSIMD instance. """ - return "ComplexSIMD[{}]({}, {})".format(String(dtype), self.re, self.im) + return "ComplexSIMD[{}]({}, {})".format(String(Self.dtype), self.re, self.im) - fn __getitem__(self, idx: Int) raises -> SIMD[dtype, size]: + fn __getitem__(self, idx: Int) raises -> SIMD[Self.dtype, Self.width]: """ Gets the real or imaginary part of the ComplexSIMD instance. @@ -280,6 +276,22 @@ struct ComplexSIMD[ else: raise Error("Index out of range") + fn __setitem__(mut self, idx: Int, value: SIMD[Self.dtype, Self.width]) raises: + """ + Sets the real and imaginary parts of the ComplexSIMD instance. + + Arguments: + self: The ComplexSIMD instance to modify. + idx: The index to access (0 for real, 1 for imaginary). + value: The new value to set. + """ + if idx == 0: + self.re = value + elif idx == 1: + self.im = value + else: + raise Error("Index out of range") + fn __setitem__(mut self, idx: Int, value: Self) raises: """ Sets the real and imaginary parts of the ComplexSIMD instance. @@ -296,28 +308,32 @@ struct ComplexSIMD[ else: raise Error("Index out of range") - fn __setitem__( - mut self, idx: Int, re: SIMD[dtype, size], im: SIMD[dtype, size] + fn item(self, idx: Int) raises -> SIMD[Self.dtype, Self.width]: + """ + Gets the real or imaginary part of the ComplexSIMD instance. + """ + return self[idx] + + fn itemset( + mut self, val: ComplexSIMD[Self.dtype, Self.width] ): """ Sets the real and imaginary parts of the ComplexSIMD instance. Arguments: self: The ComplexSIMD instance to modify. - idx: The index to access (0 for real, 1 for imaginary). - re: The new value for the real part. - im: The new value for the imaginary part. + val: The new value for the real and imaginary parts. """ - self.re = re - self.im = im + self.re = val.re + self.im = val.im - fn __abs__(self) -> SIMD[dtype, size]: + fn __abs__(self) -> SIMD[Self.dtype, Self.width]: """ Returns the magnitude of the ComplexSIMD instance. """ return sqrt(self.re * self.re + self.im * self.im) - fn norm(self) -> SIMD[dtype, size]: + fn norm(self) -> SIMD[Self.dtype, Self.width]: """ Returns the squared magnitude of the ComplexSIMD instance. """ @@ -329,14 +345,14 @@ struct ComplexSIMD[ """ return Self(self.re, -self.im) - fn real(self) -> SIMD[dtype, size]: + fn real(self) -> SIMD[Self.dtype, Self.width]: """ Returns the real part of the ComplexSIMD instance. """ return self.re - fn imag(self) -> SIMD[dtype, size]: + fn imag(self) -> SIMD[Self.dtype, Self.width]: """ Returns the imaginary part of the ComplexSIMD instance. """ - return self.im + return self.im \ No newline at end of file diff --git a/numojo/core/datatypes.mojo b/numojo/core/datatypes.mojo index 5ad1f4e8..c417f5dc 100644 --- a/numojo/core/datatypes.mojo +++ b/numojo/core/datatypes.mojo @@ -39,32 +39,6 @@ alias boolean = DType.bool # ===----------------------------------------------------------------------=== # -# Complex SIMD data type aliases -""" Data type alias for ComplexSIMD[DType.int8, 1] """ -alias ci8 = CDType.int8 -""" Data type alias for ComplexSIMD[DType.int16, 1] """ -alias ci16 = CDType.int16 -""" Data type alias for ComplexSIMD[DType.int32, 1] """ -alias ci32 = CDType.int32 -""" Data type alias for ComplexSIMD[DType.int64, 1] """ -alias ci64 = CDType.int64 -""" Data type alias for ComplexSIMD[DType.uint8, 1] """ -alias cu8 = CDType.uint8 -""" Data type alias for ComplexSIMD[DType.uint16, 1] """ -alias cu16 = CDType.uint16 -""" Data type alias for ComplexSIMD[DType.uint32, 1] """ -alias cu32 = CDType.uint32 -""" Data type alias for ComplexSIMD[DType.uint64, 1] """ -alias cu64 = CDType.uint64 -""" Data type alias for ComplexSIMD[DType.float16, 1] """ -alias cf16 = CDType.float16 -""" Data type alias for ComplexSIMD[DType.float32, 1] """ -alias cf32 = CDType.float32 -""" Data type alias for ComplexSIMD[DType.float64, 1] """ -alias cf64 = CDType.float64 - -# ===----------------------------------------------------------------------=== # - # TODO: Optimize the conditions with dict and move it to compile time # Dict can't be created at compile time rn diff --git a/numojo/prelude.mojo b/numojo/prelude.mojo index 20283f6e..ebe100f4 100644 --- a/numojo/prelude.mojo +++ b/numojo/prelude.mojo @@ -27,7 +27,6 @@ from numojo.core.matrix import Matrix from numojo.core.ndarray import NDArray from numojo.core.ndshape import Shape, NDArrayShape -from numojo.core.complex.complex_dtype import CDType from numojo.core.complex.complex_simd import ComplexSIMD, ComplexScalar from numojo.core.complex.complex_ndarray import ComplexNDArray @@ -46,15 +45,4 @@ from numojo.core.datatypes import ( f32, f64, boolean, - ci8, - ci16, - ci32, - ci64, - cu8, - cu16, - cu32, - cu64, - cf16, - cf32, - cf64, ) diff --git a/numojo/routines/creation.mojo b/numojo/routines/creation.mojo index 3a4c18fd..e49a7db0 100644 --- a/numojo/routines/creation.mojo +++ b/numojo/routines/creation.mojo @@ -100,14 +100,14 @@ fn arange[ fn arange[ - cdtype: CDType = CDType.float64, *, dtype: DType = CDType.to_dtype[cdtype]() + dtype: DType = DType.float64, ]( - start: ComplexSIMD[cdtype, dtype=dtype], - stop: ComplexSIMD[cdtype, dtype=dtype], - step: ComplexSIMD[cdtype, dtype=dtype] = ComplexSIMD[cdtype, dtype=dtype]( + start: ComplexSIMD[dtype], + stop: ComplexSIMD[dtype], + step: ComplexSIMD[dtype] = ComplexSIMD[dtype]( 1, 1 ), -) raises -> ComplexNDArray[cdtype, dtype=dtype]: +) raises -> ComplexNDArray[dtype]: """ Function that computes a series of values starting from "start" to "stop" with given "step" size. @@ -117,13 +117,12 @@ fn arange[ dtype is an integer. Parameters: - cdtype: Complex datatype of the output array. - dtype: Equivalent real datatype of the output array. + dtype: Complex datatype of the output array. Args: - start: ComplexSIMD[cdtype] - Start value. - stop: ComplexSIMD[cdtype] - End value. - step: ComplexSIMD[cdtype] - Step size between each element (default 1). + start: ComplexSIMD[dtype] - Start value. + stop: ComplexSIMD[dtype] - End value. + step: ComplexSIMD[dtype] - Step size between each element (default 1). Returns: A ComplexNDArray of datatype `dtype` with elements ranging from `start` to `stop` incremented with `step`. @@ -136,13 +135,13 @@ fn arange[ num_re, num_im ) ) - var result: ComplexNDArray[cdtype, dtype=dtype] = ComplexNDArray[ - cdtype, dtype=dtype + var result: ComplexNDArray[dtype] = ComplexNDArray[ + dtype ](Shape(num_re)) for idx in range(num_re): result.store[width=1]( idx, - ComplexSIMD[cdtype, dtype=dtype]( + ComplexSIMD[dtype]( start.re + step.re * idx, start.im + step.im * idx ), ) @@ -151,9 +150,9 @@ fn arange[ fn arange[ - cdtype: CDType = CDType.float64, *, dtype: DType = CDType.to_dtype[cdtype]() -](stop: ComplexSIMD[cdtype, dtype=dtype]) raises -> ComplexNDArray[ - cdtype, dtype=dtype + dtype: DType = DType.float64, +](stop: ComplexSIMD[dtype]) raises -> ComplexNDArray[ + dtype ]: """ (Overload) When start is 0 and step is 1. @@ -168,13 +167,13 @@ fn arange[ ) ) - var result: ComplexNDArray[cdtype, dtype=dtype] = ComplexNDArray[ - cdtype, dtype=dtype + var result: ComplexNDArray[dtype] = ComplexNDArray[ + dtype ](Shape(size_re)) for i in range(size_re): result.store[width=1]( i, - ComplexSIMD[cdtype, dtype=dtype]( + ComplexSIMD[dtype]( Scalar[dtype](i), Scalar[dtype](i) ), ) @@ -305,14 +304,14 @@ fn _linspace_parallel[ fn linspace[ - cdtype: CDType = CDType.float64, *, dtype: DType = CDType.to_dtype[cdtype]() + dtype: DType = DType.float64, ]( - start: ComplexSIMD[cdtype, dtype=dtype], - stop: ComplexSIMD[cdtype, dtype=dtype], + start: ComplexSIMD[dtype], + stop: ComplexSIMD[dtype], num: Int = 50, endpoint: Bool = True, parallel: Bool = False, -) raises -> ComplexNDArray[cdtype, dtype=dtype]: +) raises -> ComplexNDArray[dtype]: """ Function that computes a series of linearly spaced values starting from "start" to "stop" with given size. Wrapper function for _linspace_serial, _linspace_parallel. @@ -320,8 +319,7 @@ fn linspace[ Error if dtype is an integer. Parameters: - cdtype: Complex datatype of the output array. - dtype: Equivalent real datatype of the output array. + dtype: Complex datatype of the output array. Args: start: Start value. @@ -336,27 +334,26 @@ fn linspace[ """ constrained[not dtype.is_integral()]() if parallel: - return _linspace_parallel[cdtype, dtype=dtype]( + return _linspace_parallel[dtype]( start, stop, num, endpoint ) else: - return _linspace_serial[cdtype, dtype=dtype](start, stop, num, endpoint) + return _linspace_serial[dtype](start, stop, num, endpoint) fn _linspace_serial[ - cdtype: CDType = CDType.float64, *, dtype: DType = CDType.to_dtype[cdtype]() + dtype: DType = DType.float64, ]( - start: ComplexSIMD[cdtype, dtype=dtype], - stop: ComplexSIMD[cdtype, dtype=dtype], + start: ComplexSIMD[dtype], + stop: ComplexSIMD[dtype], num: Int, endpoint: Bool = True, -) raises -> ComplexNDArray[cdtype, dtype=dtype]: +) raises -> ComplexNDArray[dtype]: """ Generate a linearly spaced NDArray of `num` elements between `start` and `stop` using naive for loop. Parameters: - cdtype: Complex datatype of the output array. - dtype: Equivalent real datatype of the output array. + dtype: Complex datatype of the output array. Args: start: The starting value of the NDArray. @@ -367,8 +364,8 @@ fn _linspace_serial[ Returns: A ComplexNDArray of `dtype` with `num` linearly spaced elements between `start` and `stop`. """ - var result: ComplexNDArray[cdtype, dtype=dtype] = ComplexNDArray[ - cdtype, dtype=dtype + var result: ComplexNDArray[dtype] = ComplexNDArray[ + dtype ](Shape(num)) if endpoint: @@ -377,7 +374,7 @@ fn _linspace_serial[ for i in range(num): result.store[width=1]( i, - ComplexSIMD[cdtype, dtype=dtype]( + ComplexSIMD[dtype]( start.re + step_re * i, start.im + step_im * i ), ) @@ -388,7 +385,7 @@ fn _linspace_serial[ for i in range(num): result.store[width=1]( i, - ComplexSIMD[cdtype, dtype=dtype]( + ComplexSIMD[dtype]( start.re + step_re * i, start.im + step_im * i ), ) @@ -397,19 +394,18 @@ fn _linspace_serial[ fn _linspace_parallel[ - cdtype: CDType = CDType.float64, *, dtype: DType = CDType.to_dtype[cdtype]() + dtype: DType = DType.float64, ]( - start: ComplexSIMD[cdtype, dtype=dtype], - stop: ComplexSIMD[cdtype, dtype=dtype], + start: ComplexSIMD[dtype], + stop: ComplexSIMD[dtype], num: Int, endpoint: Bool = True, -) raises -> ComplexNDArray[cdtype, dtype=dtype]: +) raises -> ComplexNDArray[dtype]: """ Generate a linearly spaced ComplexNDArray of `num` elements between `start` and `stop` using parallelization. Parameters: - cdtype: Complex datatype of the output array. - dtype: Equivalent real datatype of the output array. + dtype: Complex datatype of the output array. Args: start: The starting value of the ComplexNDArray. @@ -420,8 +416,8 @@ fn _linspace_parallel[ Returns: A ComplexNDArray of `dtype` with `num` linearly spaced elements between `start` and `stop`. """ - var result: ComplexNDArray[cdtype, dtype=dtype] = ComplexNDArray[ - cdtype, dtype=dtype + var result: ComplexNDArray[dtype] = ComplexNDArray[ + dtype ](Shape(num)) alias nelts = simdwidthof[dtype]() @@ -436,7 +432,7 @@ fn _linspace_parallel[ try: result.store[width=1]( idx, - ComplexSIMD[cdtype, dtype=dtype]( + ComplexSIMD[dtype]( start.re + step_re * idx, start.im + step_im * idx ), ) @@ -454,7 +450,7 @@ fn _linspace_parallel[ try: result.store[width=1]( idx, - ComplexSIMD[cdtype, dtype=dtype]( + ComplexSIMD[dtype]( start.re + step_re * idx, start.im + step_im * idx ), ) @@ -605,17 +601,17 @@ fn _logspace_parallel[ fn logspace[ - cdtype: CDType = CDType.float64, *, dtype: DType = CDType.to_dtype[cdtype]() + dtype: DType = DType.float64, ]( - start: ComplexSIMD[cdtype, dtype=dtype], - stop: ComplexSIMD[cdtype, dtype=dtype], + start: ComplexSIMD[dtype], + stop: ComplexSIMD[dtype], num: Int, endpoint: Bool = True, - base: ComplexSIMD[cdtype, dtype=dtype] = ComplexSIMD[cdtype, dtype=dtype]( + base: ComplexSIMD[dtype] = ComplexSIMD[dtype]( 10.0, 10.0 ), parallel: Bool = False, -) raises -> ComplexNDArray[cdtype, dtype=dtype]: +) raises -> ComplexNDArray[dtype]: """ Generate a logrithmic spaced ComplexNDArray of `num` elements between `start` and `stop`. Wrapper function for _logspace_serial, _logspace_parallel functions. @@ -623,8 +619,7 @@ fn logspace[ Error if dtype is an integer. Parameters: - cdtype: Complex datatype of the output array. - dtype: Equivalent real datatype of the output array. + dtype: Complex datatype of the output array. Args: start: The starting value of the ComplexNDArray. @@ -639,7 +634,7 @@ fn logspace[ """ constrained[not dtype.is_integral()]() if parallel: - return _logspace_parallel[cdtype, dtype=dtype]( + return _logspace_parallel[dtype]( start, stop, num, @@ -647,7 +642,7 @@ fn logspace[ endpoint, ) else: - return _logspace_serial[cdtype, dtype=dtype]( + return _logspace_serial[dtype]( start, stop, num, @@ -657,20 +652,19 @@ fn logspace[ fn _logspace_serial[ - cdtype: CDType = CDType.float64, *, dtype: DType = CDType.to_dtype[cdtype]() + dtype: DType = DType.float64, ]( - start: ComplexSIMD[cdtype, dtype=dtype], - stop: ComplexSIMD[cdtype, dtype=dtype], + start: ComplexSIMD[dtype], + stop: ComplexSIMD[dtype], num: Int, - base: ComplexSIMD[cdtype, dtype=dtype], + base: ComplexSIMD[dtype], endpoint: Bool = True, -) raises -> ComplexNDArray[cdtype, dtype=dtype]: +) raises -> ComplexNDArray[dtype]: """ Generate a logarithmic spaced ComplexNDArray of `num` elements between `start` and `stop` using naive for loop. Parameters: - cdtype: Complex datatype of the output array. - dtype: Equivalent real datatype of the output array. + dtype: Complex datatype of the output array. Args: start: The starting value of the ComplexNDArray. @@ -682,8 +676,8 @@ fn _logspace_serial[ Returns: A ComplexNDArray of `dtype` with `num` logarithmic spaced elements between `start` and `stop`. """ - var result: ComplexNDArray[cdtype, dtype=dtype] = ComplexNDArray[ - cdtype, dtype=dtype + var result: ComplexNDArray[dtype] = ComplexNDArray[ + dtype ](NDArrayShape(num)) if endpoint: @@ -692,7 +686,7 @@ fn _logspace_serial[ for i in range(num): result.store[1]( i, - ComplexSIMD[cdtype, dtype=dtype]( + ComplexSIMD[dtype]( base.re ** (start.re + step_re * i), base.im ** (start.im + step_im * i), ), @@ -703,7 +697,7 @@ fn _logspace_serial[ for i in range(num): result.store[1]( i, - ComplexSIMD[cdtype, dtype=dtype]( + ComplexSIMD[dtype]( base.re ** (start.re + step_re * i), base.im ** (start.im + step_im * i), ), @@ -712,20 +706,19 @@ fn _logspace_serial[ fn _logspace_parallel[ - cdtype: CDType = CDType.float64, *, dtype: DType = CDType.to_dtype[cdtype]() + dtype: DType = DType.float64, ]( - start: ComplexSIMD[cdtype, dtype=dtype], - stop: ComplexSIMD[cdtype, dtype=dtype], + start: ComplexSIMD[dtype], + stop: ComplexSIMD[dtype], num: Int, - base: ComplexSIMD[cdtype, dtype=dtype], + base: ComplexSIMD[dtype], endpoint: Bool = True, -) raises -> ComplexNDArray[cdtype, dtype=dtype]: +) raises -> ComplexNDArray[dtype]: """ Generate a logarithmic spaced ComplexNDArray of `num` elements between `start` and `stop` using parallelization. Parameters: - cdtype: Complex datatype of the output array. - dtype: Equivalent real datatype of the output array. + dtype: Complex datatype of the output array. Args: start: The starting value of the ComplexNDArray. @@ -737,8 +730,8 @@ fn _logspace_parallel[ Returns: A ComplexNDArray of `dtype` with `num` logarithmic spaced elements between `start` and `stop`. """ - var result: ComplexNDArray[cdtype, dtype=dtype] = ComplexNDArray[ - cdtype, dtype=dtype + var result: ComplexNDArray[dtype] = ComplexNDArray[ + dtype ](NDArrayShape(num)) if endpoint: @@ -750,7 +743,7 @@ fn _logspace_parallel[ try: result.store[1]( idx, - ComplexSIMD[cdtype, dtype=dtype]( + ComplexSIMD[dtype]( base.re ** (start.re + step_re * idx), base.im ** (start.im + step_im * idx), ), @@ -769,7 +762,7 @@ fn _logspace_parallel[ try: result.store[1]( idx, - ComplexSIMD[cdtype, dtype=dtype]( + ComplexSIMD[dtype]( base.re ** (start.re + step_re * idx), base.im ** (start.im + step_im * idx), ), @@ -834,13 +827,13 @@ fn geomspace[ fn geomspace[ - cdtype: CDType = CDType.float64, *, dtype: DType = CDType.to_dtype[cdtype]() + dtype: DType = DType.float64, ]( - start: ComplexSIMD[cdtype, dtype=dtype], - stop: ComplexSIMD[cdtype, dtype=dtype], + start: ComplexSIMD[dtype], + stop: ComplexSIMD[dtype], num: Int, endpoint: Bool = True, -) raises -> ComplexNDArray[cdtype, dtype=dtype]: +) raises -> ComplexNDArray[dtype]: """ Generate a ComplexNDArray of `num` elements between `start` and `stop` in a geometric series. @@ -848,8 +841,7 @@ fn geomspace[ Error if dtype is an integer. Parameters: - cdtype: Complex datatype of the output array. - dtype: Equivalent real datatype of the output array. + dtype: Complex datatype of the output array. Args: start: The starting value of the ComplexNDArray. @@ -863,35 +855,35 @@ fn geomspace[ constrained[ not dtype.is_integral(), "Int type will result to precision errors." ]() - var a: ComplexSIMD[cdtype, dtype=dtype] = start + var a: ComplexSIMD[dtype] = start if endpoint: - var result: ComplexNDArray[cdtype, dtype=dtype] = ComplexNDArray[ - cdtype, dtype=dtype + var result: ComplexNDArray[dtype] = ComplexNDArray[ + dtype ](NDArrayShape(num)) - var base: ComplexSIMD[cdtype, dtype=dtype] = (stop / start) + var base: ComplexSIMD[dtype] = (stop / start) var power: Scalar[dtype] = 1 / Scalar[dtype](num - 1) - var r: ComplexSIMD[cdtype, dtype=dtype] = base**power + var r: ComplexSIMD[dtype] = base**power for i in range(num): result.store[1]( i, - ComplexSIMD[cdtype, dtype=dtype]( + ComplexSIMD[dtype]( a.re * r.re**i, a.im * r.im**i ), ) return result^ else: - var result: ComplexNDArray[cdtype, dtype=dtype] = ComplexNDArray[ - cdtype, dtype=dtype + var result: ComplexNDArray[dtype] = ComplexNDArray[ + dtype ](NDArrayShape(num)) - var base: ComplexSIMD[cdtype, dtype=dtype] = (stop / start) + var base: ComplexSIMD[dtype] = (stop / start) var power: Scalar[dtype] = 1 / Scalar[dtype](num) - var r: ComplexSIMD[cdtype, dtype=dtype] = base**power + var r: ComplexSIMD[dtype] = base**power for i in range(num): result.store[1]( i, - ComplexSIMD[cdtype, dtype=dtype]( + ComplexSIMD[dtype]( a.re * r.re**i, a.im * r.im**i ), ) @@ -951,50 +943,49 @@ fn empty_like[ return NDArray[dtype](shape=array.shape) -fn empty[ - cdtype: CDType = CDType.float64, *, dtype: DType = CDType.to_dtype[cdtype]() -](shape: NDArrayShape) raises -> ComplexNDArray[cdtype, dtype=dtype]: - """ - Generate an empty ComplexNDArray of given shape with arbitrary values. +# fn empty[ +# dtype: DType = DType.float64, +# ](shape: NDArrayShape) raises -> ComplexNDArray[dtype]: +# """ +# Generate an empty ComplexNDArray of given shape with arbitrary values. - Parameters: - cdtype: Complex datatype of the output array. - dtype: Equivalent real datatype of the output array. +# Parameters: +# dtype: Complex datatype of the output array. +# dtype: Equivalent real datatype of the output array. - Args: - shape: Shape of the ComplexNDArray. +# Args: +# shape: Shape of the ComplexNDArray. - Returns: - A ComplexNDArray of `dtype` with given `shape`. - """ - return ComplexNDArray[cdtype, dtype=dtype](shape=shape) +# Returns: +# A ComplexNDArray of `dtype` with given `shape`. +# """ +# return ComplexNDArray[dtype](shape=shape) -fn empty[ - cdtype: CDType = CDType.float64, *, dtype: DType = CDType.to_dtype[cdtype]() -](shape: List[Int]) raises -> ComplexNDArray[cdtype, dtype=dtype]: - """Overload of function `empty` that reads a list of ints.""" - return empty[cdtype, dtype=dtype](shape=NDArrayShape(shape)) +# fn empty[ +# dtype: DType = DType.float64, +# ](shape: List[Int]) raises -> ComplexNDArray[dtype]: +# """Overload of function `empty` that reads a list of ints.""" +# return empty[dtype](shape=NDArrayShape(shape)) -fn empty[ - cdtype: CDType = CDType.float64, *, dtype: DType = CDType.to_dtype[cdtype]() -](shape: VariadicList[Int]) raises -> ComplexNDArray[cdtype, dtype=dtype]: - """Overload of function `empty` that reads a variadic list of ints.""" - return empty[cdtype, dtype=dtype](shape=NDArrayShape(shape)) +# fn empty[ +# dtype: DType = DType.float64, +# ](shape: VariadicList[Int]) raises -> ComplexNDArray[dtype]: +# """Overload of function `empty` that reads a variadic list of ints.""" +# return empty[dtype](shape=NDArrayShape(shape)) fn empty_like[ - cdtype: CDType = CDType.float64, *, dtype: DType = CDType.to_dtype[cdtype]() -](array: ComplexNDArray[cdtype, dtype=dtype]) raises -> ComplexNDArray[ - cdtype, dtype=dtype + dtype: DType = DType.float64, +](array: ComplexNDArray[dtype]) raises -> ComplexNDArray[ + dtype ]: """ Generate an empty ComplexNDArray of the same shape as `array`. Parameters: - cdtype: Complex datatype of the output array. - dtype: Equivalent real datatype of the output array. + dtype: Complex datatype of the output array. Args: array: ComplexNDArray to be used as a reference for the shape. @@ -1002,7 +993,7 @@ fn empty_like[ Returns: A ComplexNDArray of `dtype` with the same shape as `array`. """ - return ComplexNDArray[cdtype, dtype=dtype](shape=array.shape) + return ComplexNDArray[dtype](shape=array.shape) fn eye[dtype: DType = DType.float64](N: Int, M: Int) raises -> NDArray[dtype]: @@ -1026,32 +1017,32 @@ fn eye[dtype: DType = DType.float64](N: Int, M: Int) raises -> NDArray[dtype]: return result^ -fn eye[ - cdtype: CDType = CDType.float64, *, dtype: DType = CDType.to_dtype[cdtype]() -](N: Int, M: Int) raises -> ComplexNDArray[cdtype, dtype=dtype]: - """ - Return a 2-D ComplexNDArray with ones on the diagonal and zeros elsewhere. +# fn eye[ +# dtype: DType = DType.float64, +# ](N: Int, M: Int) raises -> ComplexNDArray[dtype]: +# """ +# Return a 2-D ComplexNDArray with ones on the diagonal and zeros elsewhere. - Parameters: - cdtype: Complex datatype of the output array. - dtype: Equivalent real datatype of the output array. +# Parameters: +# dtype: Complex datatype of the output array. +# dtype: Equivalent real datatype of the output array. - Args: - N: Number of rows in the matrix. - M: Number of columns in the matrix. +# Args: +# N: Number of rows in the matrix. +# M: Number of columns in the matrix. - Returns: - A ComplexNDArray of `dtype` with size N x M and ones on the diagonals. - """ - var result: ComplexNDArray[cdtype, dtype=dtype] = zeros[ - cdtype, dtype=dtype - ](NDArrayShape(N, M)) - var one: ComplexSIMD[cdtype, dtype=dtype] = ComplexSIMD[ - cdtype, dtype=dtype - ](1, 1) - for i in range(min(N, M)): - result.store[1](i, i, val=one) - return result^ +# Returns: +# A ComplexNDArray of `dtype` with size N x M and ones on the diagonals. +# """ +# var result: ComplexNDArray[dtype] = zeros[ +# dtype +# ](NDArrayShape(N, M)) +# var one: ComplexSIMD[dtype] = ComplexSIMD[ +# dtype +# ](1, 1) +# for i in range(min(N, M)): +# result.store[1](i, i, val=one) +# return result^ fn identity[dtype: DType = DType.float64](N: Int) raises -> NDArray[dtype]: @@ -1074,31 +1065,31 @@ fn identity[dtype: DType = DType.float64](N: Int) raises -> NDArray[dtype]: return result^ -fn identity[ - cdtype: CDType = CDType.float64, *, dtype: DType = CDType.to_dtype[cdtype]() -](N: Int) raises -> ComplexNDArray[cdtype, dtype=dtype]: - """ - Generate an Complex identity matrix of size N x N. +# fn identity[ +# dtype: DType = DType.float64, +# ](N: Int) raises -> ComplexNDArray[dtype]: +# """ +# Generate an Complex identity matrix of size N x N. - Parameters: - cdtype: Complex datatype of the output array. - dtype: Equivalent real datatype of the output array. +# Parameters: +# dtype: Complex datatype of the output array. +# dtype: Equivalent real datatype of the output array. - Args: - N: Size of the matrix. +# Args: +# N: Size of the matrix. - Returns: - A ComplexNDArray of `dtype` with size N x N and ones on the diagonals. - """ - var result: ComplexNDArray[cdtype, dtype=dtype] = zeros[ - cdtype, dtype=dtype - ](NDArrayShape(N, N)) - var one: ComplexSIMD[cdtype, dtype=dtype] = ComplexSIMD[ - cdtype, dtype=dtype - ](1, 1) - for i in range(N): - result.store[1](i, i, val=one) - return result^ +# Returns: +# A ComplexNDArray of `dtype` with size N x N and ones on the diagonals. +# """ +# var result: ComplexNDArray[dtype] = zeros[ +# dtype +# ](NDArrayShape(N, N)) +# var one: ComplexSIMD[dtype] = ComplexSIMD[ +# dtype +# ](1, 1) +# for i in range(N): +# result.store[1](i, i, val=one) +# return result^ fn ones[ @@ -1153,54 +1144,53 @@ fn ones_like[ return ones[dtype](shape=array.shape) -fn ones[ - cdtype: CDType = CDType.float64, *, dtype: DType = CDType.to_dtype[cdtype]() -](shape: NDArrayShape) raises -> ComplexNDArray[cdtype, dtype=dtype]: - """ - Generate a ComplexNDArray of ones with given shape filled with ones. +# fn ones[ +# dtype: DType = DType.float64, +# ](shape: NDArrayShape) raises -> ComplexNDArray[dtype]: +# """ +# Generate a ComplexNDArray of ones with given shape filled with ones. - It calls the function `full`. +# It calls the function `full`. - Parameters: - cdtype: Complex datatype of the output array. - dtype: Equivalent real datatype of the output array. +# Parameters: +# dtype: Complex datatype of the output array. +# dtype: Equivalent real datatype of the output array. - Args: - shape: Shape of the ComplexNDArray. +# Args: +# shape: Shape of the ComplexNDArray. - Returns: - A ComplexNDArray of `dtype` with given `shape`. - """ - return full[cdtype, dtype=dtype]( - shape=shape, fill_value=ComplexSIMD[cdtype, dtype=dtype](1, 1) - ) +# Returns: +# A ComplexNDArray of `dtype` with given `shape`. +# """ +# return full[dtype]( +# shape=shape, fill_value=ComplexSIMD[dtype](1, 1) +# ) -fn ones[ - cdtype: CDType = CDType.float64, *, dtype: DType = CDType.to_dtype[cdtype]() -](shape: List[Int]) raises -> ComplexNDArray[cdtype, dtype=dtype]: - """Overload of function `ones` that reads a list of ints.""" - return ones[cdtype, dtype=dtype](shape=NDArrayShape(shape)) +# fn ones[ +# dtype: DType = DType.float64, +# ](shape: List[Int]) raises -> ComplexNDArray[dtype]: +# """Overload of function `ones` that reads a list of ints.""" +# return ones[dtype](shape=NDArrayShape(shape)) -fn ones[ - cdtype: CDType = CDType.float64, *, dtype: DType = CDType.to_dtype[cdtype]() -](shape: VariadicList[Int]) raises -> ComplexNDArray[cdtype, dtype=dtype]: - """Overload of function `ones` that reads a variadic of ints.""" - return ones[cdtype, dtype=dtype](shape=NDArrayShape(shape)) +# fn ones[ +# dtype: DType = DType.float64, +# ](shape: VariadicList[Int]) raises -> ComplexNDArray[dtype]: +# """Overload of function `ones` that reads a variadic of ints.""" +# return ones[dtype](shape=NDArrayShape(shape)) fn ones_like[ - cdtype: CDType = CDType.float64, *, dtype: DType = CDType.to_dtype[cdtype]() -](array: ComplexNDArray[cdtype, dtype=dtype]) raises -> ComplexNDArray[ - cdtype, dtype=dtype + dtype: DType = DType.float64, +](array: ComplexNDArray[dtype]) raises -> ComplexNDArray[ + dtype ]: """ Generate a ComplexNDArray of the same shape as `a` filled with ones. Parameters: - cdtype: Complex datatype of the output array. - dtype: Equivalent real datatype of the output array. + dtype: Complex datatype of the output array. Args: array: ComplexNDArray to be used as a reference for the shape. @@ -1208,7 +1198,7 @@ fn ones_like[ Returns: A ComplexNDArray of `dtype` with the same shape as `a` filled with ones. """ - return ones[cdtype, dtype=dtype](shape=array.shape) + return full[dtype](shape=array.shape, fill_value=ComplexSIMD[dtype](1, 1)) fn zeros[ @@ -1265,55 +1255,54 @@ fn zeros_like[ return full[dtype](shape=array.shape, fill_value=0) -fn zeros[ - cdtype: CDType = CDType.float64, *, dtype: DType = CDType.to_dtype[cdtype]() -](shape: NDArrayShape) raises -> ComplexNDArray[cdtype, dtype=dtype]: - """ - Generate a ComplexNDArray of zeros with given shape. +# fn zeros[ +# dtype: DType = DType.float64, +# ](shape: NDArrayShape) raises -> ComplexNDArray[dtype]: +# """ +# Generate a ComplexNDArray of zeros with given shape. - It calls the function `full`. +# It calls the function `full`. - Parameters: - cdtype: Complex datatype of the output array. - dtype: Equivalent real datatype of the output array. +# Parameters: +# dtype: Complex datatype of the output array. +# dtype: Equivalent real datatype of the output array. - Args: - shape: Shape of the ComplexNDArray. +# Args: +# shape: Shape of the ComplexNDArray. - Returns: - A ComplexNDArray of `dtype` with given `shape`. +# Returns: +# A ComplexNDArray of `dtype` with given `shape`. - """ - return full[cdtype, dtype=dtype]( - shape=shape, fill_value=ComplexSIMD[cdtype, dtype=dtype](0, 0) - ) +# """ +# return full[dtype]( +# shape=shape, fill_value=ComplexSIMD[dtype](0, 0) +# ) -fn zeros[ - cdtype: CDType = CDType.float64, *, dtype: DType = CDType.to_dtype[cdtype]() -](shape: List[Int]) raises -> ComplexNDArray[cdtype, dtype=dtype]: - """Overload of function `zeros` that reads a list of ints.""" - return zeros[cdtype, dtype=dtype](shape=NDArrayShape(shape)) +# fn zeros[ +# dtype: DType = DType.float64, +# ](shape: List[Int]) raises -> ComplexNDArray[dtype]: +# """Overload of function `zeros` that reads a list of ints.""" +# return zeros[dtype](shape=NDArrayShape(shape)) -fn zeros[ - cdtype: CDType = CDType.float64, *, dtype: DType = CDType.to_dtype[cdtype]() -](shape: VariadicList[Int]) raises -> ComplexNDArray[cdtype, dtype=dtype]: - """Overload of function `zeros` that reads a variadic list of ints.""" - return zeros[cdtype, dtype=dtype](shape=NDArrayShape(shape)) +# fn zeros[ +# dtype: DType = DType.float64, +# ](shape: VariadicList[Int]) raises -> ComplexNDArray[dtype]: +# """Overload of function `zeros` that reads a variadic list of ints.""" +# return zeros[dtype](shape=NDArrayShape(shape)) fn zeros_like[ - cdtype: CDType = CDType.float64, *, dtype: DType = CDType.to_dtype[cdtype]() -](array: ComplexNDArray[cdtype, dtype=dtype]) raises -> ComplexNDArray[ - cdtype, dtype=dtype + dtype: DType = DType.float64, +](array: ComplexNDArray[dtype]) raises -> ComplexNDArray[ + dtype ]: """ Generate a ComplexNDArray of the same shape as `a` filled with zeros. Parameters: - cdtype: Complex datatype of the output array. - dtype: Equivalent real datatype of the output array. + dtype: Complex datatype of the output array. Args: array: ComplexNDArray to be used as a reference for the shape. @@ -1321,8 +1310,8 @@ fn zeros_like[ Returns: A ComplexNDArray of `dtype` with the same shape as `a` filled with zeros. """ - return full[cdtype, dtype=dtype]( - shape=array.shape, fill_value=ComplexSIMD[cdtype, dtype=dtype](0, 0) + return full[dtype]( + shape=array.shape, fill_value=ComplexSIMD[dtype](0, 0) ) @@ -1394,20 +1383,18 @@ fn full_like[ A NDArray of `dtype` with the same shape as `a` filled with `fill_value`. """ return full[dtype](shape=array.shape, fill_value=fill_value, order=order) - - + fn full[ - cdtype: CDType = CDType.float64, *, dtype: DType = CDType.to_dtype[cdtype]() + dtype: DType = DType.float64 ]( shape: NDArrayShape, - fill_value: ComplexSIMD[cdtype, dtype=dtype], + fill_value: ComplexSIMD[dtype], order: String = "C", -) raises -> ComplexNDArray[cdtype, dtype=dtype]: +) raises -> ComplexNDArray[dtype]: """Initialize an ComplexNDArray of certain shape fill it with a given value. Parameters: - cdtype: Complex datatype of the output array. - dtype: Equivalent real datatype of the output array. + dtype: Complex datatype of the output array. Args: shape: Shape of the ComplexNDArray. @@ -1421,8 +1408,7 @@ fn full[ var a = nm.full[cf32](Shape(2,3,4), fill_value=ComplexSIMD[cf32](10, 10)) ``` """ - - var A = ComplexNDArray[cdtype, dtype=dtype](shape=shape, order=order) + var A = ComplexNDArray[dtype](shape=shape, order=order) for i in range(A.size): A._re._buf.ptr.store(i, fill_value.re) A._im._buf.ptr.store(i, fill_value.im) @@ -1430,44 +1416,43 @@ fn full[ fn full[ - cdtype: CDType = CDType.float64, *, dtype: DType = CDType.to_dtype[cdtype]() + dtype: DType = DType.float64 ]( shape: List[Int], - fill_value: ComplexSIMD[cdtype, dtype=dtype], + fill_value: ComplexSIMD[dtype], order: String = "C", -) raises -> ComplexNDArray[cdtype, dtype=dtype]: +) raises -> ComplexNDArray[dtype]: """Overload of function `full` that reads a list of ints.""" - return full[cdtype, dtype=dtype]( + return full[dtype]( shape=NDArrayShape(shape), fill_value=fill_value, order=order ) fn full[ - cdtype: CDType = CDType.float64, *, dtype: DType = CDType.to_dtype[cdtype]() + dtype: DType = DType.float64 ]( shape: VariadicList[Int], - fill_value: ComplexSIMD[cdtype, dtype=dtype], + fill_value: ComplexSIMD[dtype], order: String = "C", -) raises -> ComplexNDArray[cdtype, dtype=dtype]: +) raises -> ComplexNDArray[dtype]: """Overload of function `full` that reads a variadic list of ints.""" - return full[cdtype, dtype=dtype]( + return full[dtype]( shape=NDArrayShape(shape), fill_value=fill_value, order=order ) fn full_like[ - cdtype: CDType = CDType.float64, *, dtype: DType = CDType.to_dtype[cdtype]() + dtype: DType = DType.float64 ]( - array: ComplexNDArray[cdtype, dtype=dtype], - fill_value: ComplexSIMD[cdtype, dtype=dtype], + array: ComplexNDArray[dtype], + fill_value: ComplexSIMD[dtype], order: String = "C", -) raises -> ComplexNDArray[cdtype, dtype=dtype]: +) raises -> ComplexNDArray[dtype]: """ Generate a ComplexNDArray of the same shape as `a` filled with `fill_value`. Parameters: - cdtype: Complex datatype of the output array. - dtype: Equivalent real datatype of the output array. + dtype: Complex datatype of the output array. Args: array: ComplexNDArray to be used as a reference for the shape. @@ -1477,7 +1462,7 @@ fn full_like[ Returns: A ComplexNDArray of `dtype` with the same shape as `a` filled with `fill_value`. """ - return full[cdtype, dtype=dtype]( + return full[dtype]( shape=array.shape, fill_value=fill_value, order=order ) @@ -1534,16 +1519,15 @@ fn diag[ fn diag[ - cdtype: CDType = CDType.float64, *, dtype: DType = CDType.to_dtype[cdtype]() -](v: ComplexNDArray[cdtype, dtype=dtype], k: Int = 0) raises -> ComplexNDArray[ - cdtype, dtype=dtype + dtype: DType = DType.float64, +](v: ComplexNDArray[dtype], k: Int = 0) raises -> ComplexNDArray[ + dtype ]: """ Extract a diagonal or construct a diagonal ComplexNDArray. Parameters: - cdtype: Complex datatype of the output array. - dtype: Equivalent real datatype of the output array. + dtype: Complex datatype of the output array. Args: v: ComplexNDArray to extract the diagonal from. @@ -1552,7 +1536,7 @@ fn diag[ Returns: A 1-D ComplexNDArray with the diagonal of the input ComplexNDArray. """ - return ComplexNDArray[cdtype, dtype=dtype]( + return ComplexNDArray[dtype]( re=diag[dtype](v._re, k), im=diag[dtype](v._im, k), ) @@ -1591,16 +1575,15 @@ fn diagflat[ fn diagflat[ - cdtype: CDType = CDType.float64, *, dtype: DType = CDType.to_dtype[cdtype]() -](v: ComplexNDArray[cdtype, dtype=dtype], k: Int = 0) raises -> ComplexNDArray[ - cdtype, dtype=dtype + dtype: DType = DType.float64, +](v: ComplexNDArray[dtype], k: Int = 0) raises -> ComplexNDArray[ + dtype ]: """ Generate a 2-D ComplexNDArray with the flattened input as the diagonal. Parameters: - cdtype: Complex datatype of the output array. - dtype: Equivalent real datatype of the output array. + dtype: Complex datatype of the output array. Args: v: ComplexNDArray to be flattened and used as the diagonal. @@ -1609,7 +1592,7 @@ fn diagflat[ Returns: A 2-D ComplexNDArray with the flattened input as the diagonal. """ - return ComplexNDArray[cdtype, dtype=dtype]( + return ComplexNDArray[dtype]( re=diagflat[dtype](v._re, k), im=diagflat[dtype](v._im, k), ) @@ -1640,28 +1623,28 @@ fn tri[ return result^ -fn tri[ - cdtype: CDType = CDType.float64, *, dtype: DType = CDType.to_dtype[cdtype]() -](N: Int, M: Int, k: Int = 0) raises -> ComplexNDArray[cdtype, dtype=dtype]: - """ - Generate a 2-D ComplexNDArray with ones on and below the k-th diagonal. +# fn tri[ +# dtype: DType = DType.float64, +# ](N: Int, M: Int, k: Int = 0) raises -> ComplexNDArray[dtype]: +# """ +# Generate a 2-D ComplexNDArray with ones on and below the k-th diagonal. - Parameters: - cdtype: Complex datatype of the output array. - dtype: Equivalent real datatype of the output array. +# Parameters: +# dtype: Complex datatype of the output array. +# dtype: Equivalent real datatype of the output array. - Args: - N: Number of rows in the matrix. - M: Number of columns in the matrix. - k: Diagonal offset. +# Args: +# N: Number of rows in the matrix. +# M: Number of columns in the matrix. +# k: Diagonal offset. - Returns: - A 2-D ComplexNDArray with ones on and below the k-th diagonal. - """ - return ComplexNDArray[cdtype, dtype=dtype]( - re=tri[dtype](N, M, k), - im=tri[dtype](N, M, k), - ) +# Returns: +# A 2-D ComplexNDArray with ones on and below the k-th diagonal. +# """ +# return ComplexNDArray[dtype]( +# re=tri[dtype](N, M, k), +# im=tri[dtype](N, M, k), +# ) fn tril[ @@ -1707,16 +1690,15 @@ fn tril[ fn tril[ - cdtype: CDType = CDType.float64, *, dtype: DType = CDType.to_dtype[cdtype]() -](m: ComplexNDArray[cdtype, dtype=dtype], k: Int = 0) raises -> ComplexNDArray[ - cdtype, dtype=dtype + dtype: DType = DType.float64, +](m: ComplexNDArray[dtype], k: Int = 0) raises -> ComplexNDArray[ + dtype ]: """ Zero out elements above the k-th diagonal. Parameters: - cdtype: Complex datatype of the output array. - dtype: Equivalent real datatype of the output array. + dtype: Complex datatype of the output array. Args: m: ComplexNDArray to be zeroed out. @@ -1725,7 +1707,7 @@ fn tril[ Returns: A ComplexNDArray with elements above the k-th diagonal zeroed out. """ - return ComplexNDArray[cdtype, dtype=dtype]( + return ComplexNDArray[dtype]( re=tril[dtype](m._re, k), im=tril[dtype](m._im, k), ) @@ -1774,16 +1756,15 @@ fn triu[ fn triu[ - cdtype: CDType = CDType.float64, *, dtype: DType = CDType.to_dtype[cdtype]() -](m: ComplexNDArray[cdtype, dtype=dtype], k: Int = 0) raises -> ComplexNDArray[ - cdtype, dtype=dtype + dtype: DType = DType.float64, +](m: ComplexNDArray[dtype], k: Int = 0) raises -> ComplexNDArray[ + dtype ]: """ Zero out elements below the k-th diagonal. Parameters: - cdtype: Complex datatype of the output array. - dtype: Equivalent real datatype of the output array. + dtype: Complex datatype of the output array. Args: m: ComplexNDArray to be zeroed out. @@ -1792,7 +1773,7 @@ fn triu[ Returns: A ComplexNDArray with elements below the k-th diagonal zeroed out. """ - return ComplexNDArray[cdtype, dtype=dtype]( + return ComplexNDArray[dtype]( re=triu[dtype](m._re, k), im=triu[dtype](m._im, k), ) @@ -1835,18 +1816,17 @@ fn vander[ fn vander[ - cdtype: CDType = CDType.float64, *, dtype: DType = CDType.to_dtype[cdtype]() + dtype: DType = DType.float64, ]( - x: ComplexNDArray[cdtype, dtype=dtype], + x: ComplexNDArray[dtype], N: Optional[Int] = None, increasing: Bool = False, -) raises -> ComplexNDArray[cdtype, dtype=dtype]: +) raises -> ComplexNDArray[dtype]: """ Generate a Complex Vandermonde matrix. Parameters: - cdtype: Complex datatype of the output array. - dtype: Equivalent real datatype of the output array. + dtype: Complex datatype of the output array. Args: x: 1-D input array. @@ -1856,7 +1836,7 @@ fn vander[ Returns: A Complex Vandermonde matrix. """ - return ComplexNDArray[cdtype, dtype=dtype]( + return ComplexNDArray[dtype]( re=vander[dtype](x._re, N, increasing), im=vander[dtype](x._im, N, increasing), ) @@ -1931,21 +1911,17 @@ fn astype[ fn astype[ - cdtype: CDType, //, - target: CDType, - dtype: DType = CDType.to_dtype[cdtype](), - target_dtype: DType = CDType.to_dtype[cdtype](), -](a: ComplexNDArray[cdtype, dtype=dtype]) raises -> ComplexNDArray[ - target, dtype=target_dtype + dtype: DType, //, + target: DType, +](a: ComplexNDArray[dtype]) raises -> ComplexNDArray[ + target ]: """ Cast a ComplexNDArray to a different dtype. Parameters: - cdtype: Complex datatype of the input array. + dtype: Complex datatype of the input array. target: Complex datatype of the output array. - dtype: Equivalent real datatype of the output array. - target_dtype: Equivalent real datatype of the output array. Args: a: ComplexNDArray to be casted. @@ -1954,12 +1930,11 @@ fn astype[ A ComplexNDArray with the same shape and strides as `a` but with elements casted to `target`. """ - return ComplexNDArray[target, dtype=target_dtype]( - re=astype[target_dtype](a._re), - im=astype[target_dtype](a._im), + return ComplexNDArray[target]( + re=astype[target](a._re), + im=astype[target](a._im), ) - # ===------------------------------------------------------------------------===# # Construct array from other objects # ===------------------------------------------------------------------------===# @@ -2143,19 +2118,18 @@ fn array[ fn array[ - cdtype: CDType = CDType.float64, *, dtype: DType = CDType.to_dtype[cdtype]() + dtype: DType = DType.float64, ]( real: List[Scalar[dtype]], imag: List[Scalar[dtype]], shape: List[Int], order: String = "C", -) raises -> ComplexNDArray[cdtype, dtype=dtype]: +) raises -> ComplexNDArray[dtype]: """ Array creation with given data, shape and order. Parameters: - cdtype: Complex datatype of the output array. - dtype: Equivalent real datatype of the output array. + dtype: Complex datatype of the output array. Args: real: List of real data. @@ -2183,7 +2157,7 @@ fn array[ .format(len(real), len(imag)) ) - A = ComplexNDArray[cdtype, dtype=dtype](shape=shape, order=order) + A = ComplexNDArray[dtype](shape=shape, order=order) for i in range(A.size): A._re._buf.ptr[i] = real[i] @@ -2297,11 +2271,11 @@ fn _0darray[ fn _0darray[ - cdtype: CDType, *, dtype: DType = CDType.to_dtype[cdtype]() + dtype: DType, ]( - val: ComplexSIMD[cdtype, dtype=dtype], + val: ComplexSIMD[dtype], ) raises -> ComplexNDArray[ - cdtype, dtype=dtype + dtype ]: """ Initialize an special 0d complexarray (numojo scalar). @@ -2311,7 +2285,7 @@ fn _0darray[ The size is 1 (`=0!`). """ - var b = ComplexNDArray[cdtype, dtype=dtype]( + var b = ComplexNDArray[dtype]( shape=NDArrayShape(ndim=0, initialized=False), strides=NDArrayStrides(ndim=0, initialized=False), ndim=0, diff --git a/numojo/routines/io/formatting.mojo b/numojo/routines/io/formatting.mojo index 32295a83..3138093b 100644 --- a/numojo/routines/io/formatting.mojo +++ b/numojo/routines/io/formatting.mojo @@ -346,9 +346,9 @@ fn format_floating_precision[ fn format_floating_precision[ - cdtype: CDType, dtype: DType + dtype: DType ]( - value: ComplexSIMD[cdtype, dtype=dtype], + value: ComplexSIMD[dtype], precision: Int = 4, sign: Bool = False, ) raises -> String: @@ -420,9 +420,9 @@ fn format_value[ fn format_value[ - cdtype: CDType, dtype: DType + dtype: DType ]( - value: ComplexSIMD[cdtype, dtype=dtype], + value: ComplexSIMD[dtype], print_options: PrintOptions, ) raises -> String: """ diff --git a/tests/core/test_complexArray.mojo b/tests/core/test_complexArray.mojo index 1c260695..e76d69de 100644 --- a/tests/core/test_complexArray.mojo +++ b/tests/core/test_complexArray.mojo @@ -6,27 +6,27 @@ from numojo import * fn test_complex_array_init() raises: """Test initialization of ComplexArray.""" - var c1 = ComplexNDArray[cf32](Shape(2, 2)) - c1.itemset(0, ComplexSIMD[cf32](1.0, 2.0)) - c1.itemset(1, ComplexSIMD[cf32](3.0, 4.0)) - c1.itemset(2, ComplexSIMD[cf32](5.0, 6.0)) - c1.itemset(3, ComplexSIMD[cf32](7.0, 8.0)) + var c1 = ComplexNDArray[f32](Shape(2, 2)) + c1.itemset(0, ComplexSIMD[f32](1.0, 2.0)) + c1.itemset(1, ComplexSIMD[f32](3.0, 4.0)) + c1.itemset(2, ComplexSIMD[f32](5.0, 6.0)) + c1.itemset(3, ComplexSIMD[f32](7.0, 8.0)) assert_almost_equal(c1.item(0).re, 1.0, "init failed") assert_almost_equal(c1.item(0).im, 2.0, "init failed") fn test_complex_array_add() raises: """Test addition of ComplexArray numbers.""" - var c1 = ComplexNDArray[cf32](Shape(2, 2)) - var c2 = ComplexNDArray[cf32](Shape(2, 2)) - c1.itemset(0, ComplexSIMD[cf32](1.0, 2.0)) - c1.itemset(1, ComplexSIMD[cf32](3.0, 4.0)) - c1.itemset(2, ComplexSIMD[cf32](5.0, 6.0)) - c1.itemset(3, ComplexSIMD[cf32](7.0, 8.0)) - c2.itemset(0, ComplexSIMD[cf32](1.0, 2.0)) - c2.itemset(1, ComplexSIMD[cf32](3.0, 4.0)) - c2.itemset(2, ComplexSIMD[cf32](5.0, 6.0)) - c2.itemset(3, ComplexSIMD[cf32](7.0, 8.0)) + var c1 = ComplexNDArray[f32](Shape(2, 2)) + var c2 = ComplexNDArray[f32](Shape(2, 2)) + c1.itemset(0, ComplexSIMD[f32](1.0, 2.0)) + c1.itemset(1, ComplexSIMD[f32](3.0, 4.0)) + c1.itemset(2, ComplexSIMD[f32](5.0, 6.0)) + c1.itemset(3, ComplexSIMD[f32](7.0, 8.0)) + c2.itemset(0, ComplexSIMD[f32](1.0, 2.0)) + c2.itemset(1, ComplexSIMD[f32](3.0, 4.0)) + c2.itemset(2, ComplexSIMD[f32](5.0, 6.0)) + c2.itemset(3, ComplexSIMD[f32](7.0, 8.0)) var sum = c1 + c2 @@ -42,17 +42,17 @@ fn test_complex_array_add() raises: fn test_complex_array_sub() raises: """Test subtraction of ComplexArray numbers.""" - var c1 = ComplexNDArray[cf32](Shape(2, 2)) - var c2 = ComplexNDArray[cf32](Shape(2, 2)) - c1.itemset(0, ComplexSIMD[cf32](1.0, 2.0)) - c1.itemset(1, ComplexSIMD[cf32](3.0, 4.0)) - c1.itemset(2, ComplexSIMD[cf32](5.0, 6.0)) - c1.itemset(3, ComplexSIMD[cf32](7.0, 8.0)) - - c2.itemset(0, ComplexSIMD[cf32](3.0, 4.0)) - c2.itemset(1, ComplexSIMD[cf32](5.0, 6.0)) - c2.itemset(2, ComplexSIMD[cf32](7.0, 8.0)) - c2.itemset(3, ComplexSIMD[cf32](9.0, 10.0)) + var c1 = ComplexNDArray[f32](Shape(2, 2)) + var c2 = ComplexNDArray[f32](Shape(2, 2)) + c1.itemset(0, ComplexSIMD[f32](1.0, 2.0)) + c1.itemset(1, ComplexSIMD[f32](3.0, 4.0)) + c1.itemset(2, ComplexSIMD[f32](5.0, 6.0)) + c1.itemset(3, ComplexSIMD[f32](7.0, 8.0)) + + c2.itemset(0, ComplexSIMD[f32](3.0, 4.0)) + c2.itemset(1, ComplexSIMD[f32](5.0, 6.0)) + c2.itemset(2, ComplexSIMD[f32](7.0, 8.0)) + c2.itemset(3, ComplexSIMD[f32](9.0, 10.0)) var diff = c1 - c2 @@ -68,17 +68,17 @@ fn test_complex_array_sub() raises: fn test_complex_array_mul() raises: """Test multiplication of ComplexArray numbers.""" - var c1 = ComplexNDArray[cf32](Shape(2, 2)) - var c2 = ComplexNDArray[cf32](Shape(2, 2)) - c1.itemset(0, ComplexSIMD[cf32](1.0, 2.0)) - c1.itemset(1, ComplexSIMD[cf32](3.0, 4.0)) - c1.itemset(2, ComplexSIMD[cf32](5.0, 6.0)) - c1.itemset(3, ComplexSIMD[cf32](7.0, 8.0)) - - c2.itemset(0, ComplexSIMD[cf32](1.0, 2.0)) - c2.itemset(1, ComplexSIMD[cf32](3.0, 4.0)) - c2.itemset(2, ComplexSIMD[cf32](5.0, 6.0)) - c2.itemset(3, ComplexSIMD[cf32](7.0, 8.0)) + var c1 = ComplexNDArray[f32](Shape(2, 2)) + var c2 = ComplexNDArray[f32](Shape(2, 2)) + c1.itemset(0, ComplexSIMD[f32](1.0, 2.0)) + c1.itemset(1, ComplexSIMD[f32](3.0, 4.0)) + c1.itemset(2, ComplexSIMD[f32](5.0, 6.0)) + c1.itemset(3, ComplexSIMD[f32](7.0, 8.0)) + + c2.itemset(0, ComplexSIMD[f32](1.0, 2.0)) + c2.itemset(1, ComplexSIMD[f32](3.0, 4.0)) + c2.itemset(2, ComplexSIMD[f32](5.0, 6.0)) + c2.itemset(3, ComplexSIMD[f32](7.0, 8.0)) var prod = c1 * c2 @@ -88,17 +88,17 @@ fn test_complex_array_mul() raises: fn test_complex_array_div() raises: """Test division of ComplexArray numbers.""" - var c1 = ComplexNDArray[cf32](Shape(2, 2)) - var c2 = ComplexNDArray[cf32](Shape(2, 2)) - c1.itemset(0, ComplexSIMD[cf32](1.0, 2.0)) - c1.itemset(1, ComplexSIMD[cf32](3.0, 4.0)) - c1.itemset(2, ComplexSIMD[cf32](5.0, 6.0)) - c1.itemset(3, ComplexSIMD[cf32](7.0, 8.0)) - - c2.itemset(0, ComplexSIMD[cf32](3.0, 4.0)) - c2.itemset(1, ComplexSIMD[cf32](5.0, 6.0)) - c2.itemset(2, ComplexSIMD[cf32](7.0, 8.0)) - c2.itemset(3, ComplexSIMD[cf32](9.0, 10.0)) + var c1 = ComplexNDArray[f32](Shape(2, 2)) + var c2 = ComplexNDArray[f32](Shape(2, 2)) + c1.itemset(0, ComplexSIMD[f32](1.0, 2.0)) + c1.itemset(1, ComplexSIMD[f32](3.0, 4.0)) + c1.itemset(2, ComplexSIMD[f32](5.0, 6.0)) + c1.itemset(3, ComplexSIMD[f32](7.0, 8.0)) + + c2.itemset(0, ComplexSIMD[f32](3.0, 4.0)) + c2.itemset(1, ComplexSIMD[f32](5.0, 6.0)) + c2.itemset(2, ComplexSIMD[f32](7.0, 8.0)) + c2.itemset(3, ComplexSIMD[f32](9.0, 10.0)) var quot = c1 / c2 diff --git a/tests/core/test_complexSIMD.mojo b/tests/core/test_complexSIMD.mojo index fa0a9ec0..02587ba0 100644 --- a/tests/core/test_complexSIMD.mojo +++ b/tests/core/test_complexSIMD.mojo @@ -4,19 +4,19 @@ from numojo import * fn test_complex_init() raises: """Test initialization of ComplexSIMD.""" - var c1 = ComplexSIMD[cf32](1.0, 2.0) + var c1 = ComplexSIMD[f32](1.0, 2.0) assert_equal(c1.re, 1.0, "init failed") assert_equal(c1.im, 2.0, "init failed") - var c2 = ComplexSIMD[cf32](c1) + var c2 = ComplexSIMD[f32](c1) assert_equal(c2.re, c1.re) assert_equal(c2.im, c1.im) fn test_complex_add() raises: """Test addition of ComplexSIMD numbers.""" - var c1 = ComplexSIMD[cf32](1.0, 2.0) - var c2 = ComplexSIMD[cf32](3.0, 4.0) + var c1 = ComplexSIMD[f32](1.0, 2.0) + var c2 = ComplexSIMD[f32](3.0, 4.0) var sum = c1 + c2 assert_equal(sum.re, 4.0, "addition failed") @@ -30,8 +30,8 @@ fn test_complex_add() raises: fn test_complex_sub() raises: """Test subtraction of ComplexSIMD numbers.""" - var c1 = ComplexSIMD[cf32](3.0, 4.0) - var c2 = ComplexSIMD[cf32](1.0, 2.0) + var c1 = ComplexSIMD[f32](3.0, 4.0) + var c2 = ComplexSIMD[f32](1.0, 2.0) var diff = c1 - c2 assert_equal(diff.re, 2.0, "subtraction failed") @@ -45,8 +45,8 @@ fn test_complex_sub() raises: fn test_complex_mul() raises: """Test multiplication of ComplexSIMD numbers.""" - var c1 = ComplexSIMD[cf32](1.0, 2.0) - var c2 = ComplexSIMD[cf32](3.0, 4.0) + var c1 = ComplexSIMD[f32](1.0, 2.0) + var c2 = ComplexSIMD[f32](3.0, 4.0) # (1 + 2i)(3 + 4i) = (1*3 - 2*4) + (1*4 + 2*3)i = -5 + 10i var prod = c1 * c2 @@ -61,8 +61,8 @@ fn test_complex_mul() raises: fn test_complex_div() raises: """Test division of ComplexSIMD numbers.""" - var c1 = ComplexSIMD[cf32](1.0, 2.0) - var c2 = ComplexSIMD[cf32](3.0, 4.0) + var c1 = ComplexSIMD[f32](1.0, 2.0) + var c2 = ComplexSIMD[f32](3.0, 4.0) # (1 + 2i)/(3 + 4i) = (1*3 + 2*4 + (2*3 - 1*4)i)/(3^2 + 4^2) # = (11 + 2i)/25 From f615bda0e1ce127b1168b9a346b488cb3b22479d Mon Sep 17 00:00:00 2001 From: shivasankar Date: Tue, 15 Apr 2025 17:55:17 +0900 Subject: [PATCH 3/4] fixed minor typo in matrix --- numojo/core/matrix.mojo | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/numojo/core/matrix.mojo b/numojo/core/matrix.mojo index 64e88301..39cba632 100644 --- a/numojo/core/matrix.mojo +++ b/numojo/core/matrix.mojo @@ -1376,14 +1376,14 @@ struct Matrix[dtype: DType = DType.float64]( """ if (shape[0] == 0) and (shape[1] == 0): - var M = Matrix[dtype](shape=(1, object.size)) + var M = Matrix[dtype](shape=(1, len(object))) memcpy(M._buf.ptr, object.data, M.size) return M^ - if shape[0] * shape[1] != object.size: + if shape[0] * shape[1] != len(object): var message = String( "The input has {} elements, but the target has the shape {}x{}" - ).format(object.size, shape[0], shape[1]) + ).format(len(object), shape[0], shape[1]) raise Error(message) var M = Matrix[dtype](shape=shape) memcpy(M._buf.ptr, object.data, M.size) From d569fa0c0548531d4382125aaa00647401203bce Mon Sep 17 00:00:00 2001 From: shivasankar Date: Tue, 15 Apr 2025 17:59:59 +0900 Subject: [PATCH 4/4] formatted --- numojo/core/complex/complex_ndarray.mojo | 74 ++++++--------- numojo/core/complex/complex_simd.mojo | 26 ++++-- numojo/routines/creation.mojo | 114 +++++++---------------- numojo/routines/io/formatting.mojo | 5 +- 4 files changed, 81 insertions(+), 138 deletions(-) diff --git a/numojo/core/complex/complex_ndarray.mojo b/numojo/core/complex/complex_ndarray.mojo index 494324cf..7b17a593 100644 --- a/numojo/core/complex/complex_ndarray.mojo +++ b/numojo/core/complex/complex_ndarray.mojo @@ -87,9 +87,9 @@ from numojo.routines.statistics.averages import mean # ===----------------------------------------------------------------------===# # TODO: Add SIMD width as a parameter. @value -struct ComplexNDArray[ - dtype: DType = DType.float64 -](Stringable, Representable, CollectionElement, Sized, Writable): +struct ComplexNDArray[dtype: DType = DType.float64]( + Stringable, Representable, CollectionElement, Sized, Writable +): """ Represents a Complex N-Dimensional Array. @@ -116,7 +116,9 @@ struct ComplexNDArray[ """LIFETIME METHODS""" @always_inline("nodebug") - fn __init__(out self, owned re: NDArray[Self.dtype], owned im: NDArray[Self.dtype]): + fn __init__( + out self, owned re: NDArray[Self.dtype], owned im: NDArray[Self.dtype] + ): self._re = re self._im = im self.ndim = re.ndim @@ -402,9 +404,7 @@ struct ComplexNDArray[ im=self._im._buf.ptr[], ) - fn __getitem__( - self, index: Item - ) raises -> ComplexSIMD[Self.dtype]: + fn __getitem__(self, index: Item) raises -> ComplexSIMD[Self.dtype]: """ Get the value at the index list. @@ -729,9 +729,9 @@ struct ComplexNDArray[ # Get the shape of resulted array var shape = indices.shape.join(self.shape._pop(0)) - var result: ComplexNDArray[Self.dtype] = ComplexNDArray[ - Self.dtype - ](shape) + var result: ComplexNDArray[Self.dtype] = ComplexNDArray[Self.dtype]( + shape + ) var size_per_item = self.size // self.shape[0] # Fill in the values @@ -1064,9 +1064,7 @@ struct ComplexNDArray[ im=self._im._buf.ptr[index], ) - fn load[ - width: Int = 1 - ](self, index: Int) raises -> ComplexSIMD[Self.dtype]: + fn load[width: Int = 1](self, index: Int) raises -> ComplexSIMD[Self.dtype]: """ Safely loads a ComplexSIMD element of size `width` at `index` from the underlying buffer. @@ -1099,9 +1097,7 @@ struct ComplexNDArray[ fn load[ width: Int = 1 - ](self, *indices: Int) raises -> ComplexSIMD[ - Self.dtype, width=width - ]: + ](self, *indices: Int) raises -> ComplexSIMD[Self.dtype, width=width]: """ Safely loads a ComplexSIMD element of size `width` at given variadic indices from the underlying buffer. @@ -1325,9 +1321,7 @@ struct ComplexNDArray[ val._im, self._im, nshape, ncoefficients, nstrides, noffset, index ) - fn __setitem__( - mut self, index: Item, val: ComplexSIMD[Self.dtype] - ) raises: + fn __setitem__(mut self, index: Item, val: ComplexSIMD[Self.dtype]) raises: """ Set the value at the index list. """ @@ -1652,9 +1646,7 @@ struct ComplexNDArray[ var imag: NDArray[dtype] = math.add[dtype](self._im, other) return Self(real, imag) - fn __radd__( - mut self, other: ComplexSIMD[Self.dtype] - ) raises -> Self: + fn __radd__(mut self, other: ComplexSIMD[Self.dtype]) raises -> Self: """ Enables `ComplexSIMD + ComplexNDArray`. """ @@ -1746,9 +1738,7 @@ struct ComplexNDArray[ var imag: NDArray[dtype] = math.sub[dtype](self._im, other) return Self(real, imag) - fn __rsub__( - mut self, other: ComplexSIMD[Self.dtype] - ) raises -> Self: + fn __rsub__(mut self, other: ComplexSIMD[Self.dtype]) raises -> Self: """ Enables `ComplexSIMD - ComplexNDArray`. """ @@ -1895,9 +1885,7 @@ struct ComplexNDArray[ self._re *= other self._im *= other - fn __truediv__( - self, other: ComplexSIMD[Self.dtype] - ) raises -> Self: + fn __truediv__(self, other: ComplexSIMD[Self.dtype]) raises -> Self: """ Enables `ComplexNDArray / ComplexSIMD`. """ @@ -1913,9 +1901,7 @@ struct ComplexNDArray[ var imag: NDArray[dtype] = math.div[dtype](self._im, other) return Self(real, imag) - fn __truediv__( - self, other: ComplexNDArray[Self.dtype] - ) raises -> Self: + fn __truediv__(self, other: ComplexNDArray[Self.dtype]) raises -> Self: """ Enables `ComplexNDArray / ComplexNDArray`. """ @@ -1933,9 +1919,7 @@ struct ComplexNDArray[ var imag: NDArray[dtype] = math.div[dtype](self._im, other) return Self(real, imag) - fn __rtruediv__( - mut self, other: ComplexSIMD[Self.dtype] - ) raises -> Self: + fn __rtruediv__(mut self, other: ComplexSIMD[Self.dtype]) raises -> Self: """ Enables `ComplexSIMD / ComplexNDArray`. """ @@ -2259,7 +2243,9 @@ struct ComplexNDArray[ fn __reversed__( self, - ) raises -> _ComplexNDArrayIter[__origin_of(self._re), Self.dtype, forward=False]: + ) raises -> _ComplexNDArrayIter[ + __origin_of(self._re), Self.dtype, forward=False + ]: """ Iterates backwards over elements of the ComplexNDArray, returning copied value. @@ -2268,12 +2254,13 @@ struct ComplexNDArray[ A reversed iterator of NDArray elements. """ - return _ComplexNDArrayIter[__origin_of(self._re), Self.dtype, forward=False]( + return _ComplexNDArrayIter[ + __origin_of(self._re), Self.dtype, forward=False + ]( self, dimension=0, ) - fn itemset( mut self, index: Variant[Int, List[Int]], @@ -2386,7 +2373,9 @@ struct _ComplexNDArrayIter[ var ndim: Int var size_of_item: Int - fn __init__(out self, read a: ComplexNDArray[dtype], read dimension: Int) raises: + fn __init__( + out self, read a: ComplexNDArray[dtype], read dimension: Int + ) raises: """ Initialize the iterator. @@ -2502,8 +2491,7 @@ struct _ComplexNDArrayIter[ return res else: # 0-D array - var res = numojo.creation._0darray[dtype](ComplexSIMD[dtype]( - self.re_ptr[index], - self.im_ptr[index] - )) - return res \ No newline at end of file + var res = numojo.creation._0darray[dtype]( + ComplexSIMD[dtype](self.re_ptr[index], self.im_ptr[index]) + ) + return res diff --git a/numojo/core/complex/complex_simd.mojo b/numojo/core/complex/complex_simd.mojo index ca5040bb..d2a8d744 100644 --- a/numojo/core/complex/complex_simd.mojo +++ b/numojo/core/complex/complex_simd.mojo @@ -2,13 +2,13 @@ from math import sqrt alias ComplexScalar = ComplexSIMD[_, width=1] + @register_passable("trivial") -struct ComplexSIMD[ - dtype: DType, width: Int = 1 -](): +struct ComplexSIMD[dtype: DType, width: Int = 1](): """ Represents a Complex number SIMD type with real and imaginary parts. """ + # FIELDS """The underlying data real and imaginary parts of the complex number.""" var re: SIMD[dtype, width] @@ -25,7 +25,11 @@ struct ComplexSIMD[ self = other @always_inline - fn __init__(out self, re: SIMD[Self.dtype, Self.width], im: SIMD[Self.dtype, Self.width]): + fn __init__( + out self, + re: SIMD[Self.dtype, Self.width], + im: SIMD[Self.dtype, Self.width], + ): """ Initializes a ComplexSIMD instance with specified real and imaginary parts. @@ -256,7 +260,9 @@ struct ComplexSIMD[ Returns: String: The string representation of the ComplexSIMD instance. """ - return "ComplexSIMD[{}]({}, {})".format(String(Self.dtype), self.re, self.im) + return "ComplexSIMD[{}]({}, {})".format( + String(Self.dtype), self.re, self.im + ) fn __getitem__(self, idx: Int) raises -> SIMD[Self.dtype, Self.width]: """ @@ -276,7 +282,9 @@ struct ComplexSIMD[ else: raise Error("Index out of range") - fn __setitem__(mut self, idx: Int, value: SIMD[Self.dtype, Self.width]) raises: + fn __setitem__( + mut self, idx: Int, value: SIMD[Self.dtype, Self.width] + ) raises: """ Sets the real and imaginary parts of the ComplexSIMD instance. @@ -314,9 +322,7 @@ struct ComplexSIMD[ """ return self[idx] - fn itemset( - mut self, val: ComplexSIMD[Self.dtype, Self.width] - ): + fn itemset(mut self, val: ComplexSIMD[Self.dtype, Self.width]): """ Sets the real and imaginary parts of the ComplexSIMD instance. @@ -355,4 +361,4 @@ struct ComplexSIMD[ """ Returns the imaginary part of the ComplexSIMD instance. """ - return self.im \ No newline at end of file + return self.im diff --git a/numojo/routines/creation.mojo b/numojo/routines/creation.mojo index e49a7db0..58d18f0b 100644 --- a/numojo/routines/creation.mojo +++ b/numojo/routines/creation.mojo @@ -104,9 +104,7 @@ fn arange[ ]( start: ComplexSIMD[dtype], stop: ComplexSIMD[dtype], - step: ComplexSIMD[dtype] = ComplexSIMD[dtype]( - 1, 1 - ), + step: ComplexSIMD[dtype] = ComplexSIMD[dtype](1, 1), ) raises -> ComplexNDArray[dtype]: """ Function that computes a series of values starting from "start" to "stop" @@ -135,9 +133,7 @@ fn arange[ num_re, num_im ) ) - var result: ComplexNDArray[dtype] = ComplexNDArray[ - dtype - ](Shape(num_re)) + var result: ComplexNDArray[dtype] = ComplexNDArray[dtype](Shape(num_re)) for idx in range(num_re): result.store[width=1]( idx, @@ -151,9 +147,7 @@ fn arange[ fn arange[ dtype: DType = DType.float64, -](stop: ComplexSIMD[dtype]) raises -> ComplexNDArray[ - dtype -]: +](stop: ComplexSIMD[dtype]) raises -> ComplexNDArray[dtype]: """ (Overload) When start is 0 and step is 1. """ @@ -167,15 +161,11 @@ fn arange[ ) ) - var result: ComplexNDArray[dtype] = ComplexNDArray[ - dtype - ](Shape(size_re)) + var result: ComplexNDArray[dtype] = ComplexNDArray[dtype](Shape(size_re)) for i in range(size_re): result.store[width=1]( i, - ComplexSIMD[dtype]( - Scalar[dtype](i), Scalar[dtype](i) - ), + ComplexSIMD[dtype](Scalar[dtype](i), Scalar[dtype](i)), ) return result^ @@ -334,9 +324,7 @@ fn linspace[ """ constrained[not dtype.is_integral()]() if parallel: - return _linspace_parallel[dtype]( - start, stop, num, endpoint - ) + return _linspace_parallel[dtype](start, stop, num, endpoint) else: return _linspace_serial[dtype](start, stop, num, endpoint) @@ -364,9 +352,7 @@ fn _linspace_serial[ Returns: A ComplexNDArray of `dtype` with `num` linearly spaced elements between `start` and `stop`. """ - var result: ComplexNDArray[dtype] = ComplexNDArray[ - dtype - ](Shape(num)) + var result: ComplexNDArray[dtype] = ComplexNDArray[dtype](Shape(num)) if endpoint: var step_re: Scalar[dtype] = (stop.re - start.re) / (num - 1) @@ -416,9 +402,7 @@ fn _linspace_parallel[ Returns: A ComplexNDArray of `dtype` with `num` linearly spaced elements between `start` and `stop`. """ - var result: ComplexNDArray[dtype] = ComplexNDArray[ - dtype - ](Shape(num)) + var result: ComplexNDArray[dtype] = ComplexNDArray[dtype](Shape(num)) alias nelts = simdwidthof[dtype]() if endpoint: @@ -607,9 +591,7 @@ fn logspace[ stop: ComplexSIMD[dtype], num: Int, endpoint: Bool = True, - base: ComplexSIMD[dtype] = ComplexSIMD[dtype]( - 10.0, 10.0 - ), + base: ComplexSIMD[dtype] = ComplexSIMD[dtype](10.0, 10.0), parallel: Bool = False, ) raises -> ComplexNDArray[dtype]: """ @@ -676,9 +658,7 @@ fn _logspace_serial[ Returns: A ComplexNDArray of `dtype` with `num` logarithmic spaced elements between `start` and `stop`. """ - var result: ComplexNDArray[dtype] = ComplexNDArray[ - dtype - ](NDArrayShape(num)) + var result: ComplexNDArray[dtype] = ComplexNDArray[dtype](NDArrayShape(num)) if endpoint: var step_re: Scalar[dtype] = (stop.re - start.re) / (num - 1) @@ -730,9 +710,7 @@ fn _logspace_parallel[ Returns: A ComplexNDArray of `dtype` with `num` logarithmic spaced elements between `start` and `stop`. """ - var result: ComplexNDArray[dtype] = ComplexNDArray[ - dtype - ](NDArrayShape(num)) + var result: ComplexNDArray[dtype] = ComplexNDArray[dtype](NDArrayShape(num)) if endpoint: var step_re: Scalar[dtype] = (stop.re - start.re) / (num - 1) @@ -858,34 +836,30 @@ fn geomspace[ var a: ComplexSIMD[dtype] = start if endpoint: - var result: ComplexNDArray[dtype] = ComplexNDArray[ - dtype - ](NDArrayShape(num)) + var result: ComplexNDArray[dtype] = ComplexNDArray[dtype]( + NDArrayShape(num) + ) var base: ComplexSIMD[dtype] = (stop / start) var power: Scalar[dtype] = 1 / Scalar[dtype](num - 1) var r: ComplexSIMD[dtype] = base**power for i in range(num): result.store[1]( i, - ComplexSIMD[dtype]( - a.re * r.re**i, a.im * r.im**i - ), + ComplexSIMD[dtype](a.re * r.re**i, a.im * r.im**i), ) return result^ else: - var result: ComplexNDArray[dtype] = ComplexNDArray[ - dtype - ](NDArrayShape(num)) + var result: ComplexNDArray[dtype] = ComplexNDArray[dtype]( + NDArrayShape(num) + ) var base: ComplexSIMD[dtype] = (stop / start) var power: Scalar[dtype] = 1 / Scalar[dtype](num) var r: ComplexSIMD[dtype] = base**power for i in range(num): result.store[1]( i, - ComplexSIMD[dtype]( - a.re * r.re**i, a.im * r.im**i - ), + ComplexSIMD[dtype](a.re * r.re**i, a.im * r.im**i), ) return result^ @@ -978,9 +952,7 @@ fn empty_like[ fn empty_like[ dtype: DType = DType.float64, -](array: ComplexNDArray[dtype]) raises -> ComplexNDArray[ - dtype -]: +](array: ComplexNDArray[dtype]) raises -> ComplexNDArray[dtype]: """ Generate an empty ComplexNDArray of the same shape as `array`. @@ -1183,9 +1155,7 @@ fn ones_like[ fn ones_like[ dtype: DType = DType.float64, -](array: ComplexNDArray[dtype]) raises -> ComplexNDArray[ - dtype -]: +](array: ComplexNDArray[dtype]) raises -> ComplexNDArray[dtype]: """ Generate a ComplexNDArray of the same shape as `a` filled with ones. @@ -1295,9 +1265,7 @@ fn zeros_like[ fn zeros_like[ dtype: DType = DType.float64, -](array: ComplexNDArray[dtype]) raises -> ComplexNDArray[ - dtype -]: +](array: ComplexNDArray[dtype]) raises -> ComplexNDArray[dtype]: """ Generate a ComplexNDArray of the same shape as `a` filled with zeros. @@ -1310,9 +1278,7 @@ fn zeros_like[ Returns: A ComplexNDArray of `dtype` with the same shape as `a` filled with zeros. """ - return full[dtype]( - shape=array.shape, fill_value=ComplexSIMD[dtype](0, 0) - ) + return full[dtype](shape=array.shape, fill_value=ComplexSIMD[dtype](0, 0)) fn full[ @@ -1383,7 +1349,8 @@ fn full_like[ A NDArray of `dtype` with the same shape as `a` filled with `fill_value`. """ return full[dtype](shape=array.shape, fill_value=fill_value, order=order) - + + fn full[ dtype: DType = DType.float64 ]( @@ -1462,9 +1429,7 @@ fn full_like[ Returns: A ComplexNDArray of `dtype` with the same shape as `a` filled with `fill_value`. """ - return full[dtype]( - shape=array.shape, fill_value=fill_value, order=order - ) + return full[dtype](shape=array.shape, fill_value=fill_value, order=order) # ===------------------------------------------------------------------------===# @@ -1520,9 +1485,7 @@ fn diag[ fn diag[ dtype: DType = DType.float64, -](v: ComplexNDArray[dtype], k: Int = 0) raises -> ComplexNDArray[ - dtype -]: +](v: ComplexNDArray[dtype], k: Int = 0) raises -> ComplexNDArray[dtype]: """ Extract a diagonal or construct a diagonal ComplexNDArray. @@ -1576,9 +1539,7 @@ fn diagflat[ fn diagflat[ dtype: DType = DType.float64, -](v: ComplexNDArray[dtype], k: Int = 0) raises -> ComplexNDArray[ - dtype -]: +](v: ComplexNDArray[dtype], k: Int = 0) raises -> ComplexNDArray[dtype]: """ Generate a 2-D ComplexNDArray with the flattened input as the diagonal. @@ -1691,9 +1652,7 @@ fn tril[ fn tril[ dtype: DType = DType.float64, -](m: ComplexNDArray[dtype], k: Int = 0) raises -> ComplexNDArray[ - dtype -]: +](m: ComplexNDArray[dtype], k: Int = 0) raises -> ComplexNDArray[dtype]: """ Zero out elements above the k-th diagonal. @@ -1757,9 +1716,7 @@ fn triu[ fn triu[ dtype: DType = DType.float64, -](m: ComplexNDArray[dtype], k: Int = 0) raises -> ComplexNDArray[ - dtype -]: +](m: ComplexNDArray[dtype], k: Int = 0) raises -> ComplexNDArray[dtype]: """ Zero out elements below the k-th diagonal. @@ -1913,9 +1870,7 @@ fn astype[ fn astype[ dtype: DType, //, target: DType, -](a: ComplexNDArray[dtype]) raises -> ComplexNDArray[ - target -]: +](a: ComplexNDArray[dtype]) raises -> ComplexNDArray[target]: """ Cast a ComplexNDArray to a different dtype. @@ -1935,6 +1890,7 @@ fn astype[ im=astype[target](a._im), ) + # ===------------------------------------------------------------------------===# # Construct array from other objects # ===------------------------------------------------------------------------===# @@ -2272,11 +2228,7 @@ fn _0darray[ fn _0darray[ dtype: DType, -]( - val: ComplexSIMD[dtype], -) raises -> ComplexNDArray[ - dtype -]: +](val: ComplexSIMD[dtype],) raises -> ComplexNDArray[dtype]: """ Initialize an special 0d complexarray (numojo scalar). The ndim is 0. diff --git a/numojo/routines/io/formatting.mojo b/numojo/routines/io/formatting.mojo index 3138093b..b6300e6f 100644 --- a/numojo/routines/io/formatting.mojo +++ b/numojo/routines/io/formatting.mojo @@ -421,10 +421,7 @@ fn format_value[ fn format_value[ dtype: DType -]( - value: ComplexSIMD[dtype], - print_options: PrintOptions, -) raises -> String: +](value: ComplexSIMD[dtype], print_options: PrintOptions,) raises -> String: """ Format a complex value based on the print options.