Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ magic.lock
mojo
numojo.mojopkg
bench.mojo
test_ndarray.ipynb
test.mojo
test*.mojo
test*.ipynb
tempCodeRunnerFile.mojo

# Auto docs
Expand Down
154 changes: 95 additions & 59 deletions numojo/core/ndarray.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -2928,69 +2928,62 @@ struct NDArray[dtype: DType = DType.float64](
var padding = print_options.padding
var edge_items = print_options.edge_items

# The following code get the max value and the min value
# to determine the digits before decimals and the negative sign
# and then determine the formatted withd
var negative_sign: Bool = False # whether there should be a negative sign
var number_of_digits: Int # number of digits before or after decimal point
var number_of_digits_small_values: Int # number of digits after decimal point for small values
var formatted_width: Int # formatted width based on precision and digits before decimal points
var max_value: Scalar[dtype] = abs(
self._buf.ptr[]
) # maximum absolute value of the items
var min_value: Scalar[dtype] = abs(
self._buf.ptr[]
) # minimum absolute value of the items
var val: Scalar[dtype] # storage of value of the item

var skip: Bool
for index in range(self.size):
skip = False
var remainder = index
var indices = Item(ndim=self.ndim, initialized=False)
for i in range(self.ndim):
indices[i], remainder = divmod(
remainder, NDArrayStrides(self.shape)[i]
)
if (indices[i] >= 3) and (indices[i] < self.shape[i] - 3):
skip = True
continue
if skip:
continue

val = self._buf.ptr[_get_offset(indices, self.strides)]
if val < 0:
negative_sign = True
max_value = max(
# The following code get the max value and the min value of
# the pritable region to determine the digits before decimals and
# the negative sign and then determine the formatted width.
if dimension == 0:
var negative_sign: Bool = False # whether there should be a negative sign
var number_of_digits: Int # number of digits before or after decimal point
var number_of_digits_small_values: Int # number of digits after decimal point for small values
var formatted_width: Int # formatted width based on precision and digits before decimal points
var max_value: Scalar[dtype] = abs(
self._buf.ptr[]
) # maximum absolute value of the items
var min_value: Scalar[dtype] = abs(
self._buf.ptr[]
) # minimum absolute value of the items
var indices = Item(
ndim=self.ndim, initialized=True
) # Temporarily store the indices

self._find_max_and_min_in_printable_region(
self.shape,
self.strides,
edge_items,
indices,
negative_sign,
max_value,
abs(val),
)
min_value = min(
min_value,
abs(val),
0,
)
number_of_digits = Int(log10(Float64(max_value))) + 1
number_of_digits_small_values = abs(Int(log10(Float64(min_value)))) + 1

if dtype.is_floating_point():
formatted_width = (
print_options.precision
+ 1
+ number_of_digits
+ Int(negative_sign)

number_of_digits = Int(log10(Float64(max_value))) + 1
number_of_digits_small_values = (
abs(Int(log10(Float64(min_value)))) + 1
)
else:
formatted_width = number_of_digits + Int(negative_sign)

# If the number is not too wide,
# or digits after decimal point is not many
# format it as a floating point.
if (formatted_width <= 14) and (number_of_digits_small_values <= 2):
print_options.formatted_width = formatted_width
# Otherwise, format it as a scientific number.
else:
print_options.float_format = "scientific"
print_options.formatted_width = 7 + print_options.precision

if dtype.is_floating_point():
formatted_width = (
print_options.precision
+ 1
+ number_of_digits
+ Int(negative_sign)
)
# If the number is not too wide,
# or digits after decimal point is not many
# format it as a floating point.
if (formatted_width <= 14) and (
number_of_digits_small_values <= 2
):
print_options.formatted_width = formatted_width
# Otherwise, format it as a scientific number.
else:
print_options.float_format = "scientific"
print_options.formatted_width = 7 + print_options.precision
else: # type is integral
print_options.formatted_width = number_of_digits + Int(
negative_sign
)

if dimension == self.ndim - 1:
var result: String = String("[") + padding
Expand Down Expand Up @@ -3085,6 +3078,49 @@ struct NDArray[dtype: DType = DType.float64](
result = result + "]"
return result

fn _find_max_and_min_in_printable_region(
self,
shape: NDArrayShape,
strides: NDArrayStrides,
edge_items: Int,
mut indices: Item,
mut negative_sign: Bool, # whether there should be a negative sign
mut max_value: Scalar[dtype], # maximum absolute value of the items
mut min_value: Scalar[dtype], # minimum absolute value of the items
current_axis: Int = 0,
) raises:
"""
Travel through the printable region of the array to find maximum and minimum values.
"""
var offsets = List[Int]()
if shape[current_axis] > edge_items * 2:
for i in range(0, edge_items):
offsets.append(i)
offsets.append(shape[current_axis] - 1 - i)
else:
for i in range(0, shape[current_axis]):
offsets.append(i)

for index_at_axis in offsets:
indices._buf[current_axis] = index_at_axis[]
if current_axis == shape.ndim - 1:
var val = (self._buf.ptr + _get_offset(indices, strides))[]
if val < 0:
negative_sign = True
max_value = max(max_value, abs(val))
min_value = min(min_value, abs(val))
else:
self._find_max_and_min_in_printable_region(
shape,
strides,
edge_items,
indices,
negative_sign,
max_value,
min_value,
current_axis + 1,
)

# ===-------------------------------------------------------------------===#
# OTHER METHODS
# (Sorted alphabetically)
Expand Down
18 changes: 18 additions & 0 deletions numojo/routines/io/formatting.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -44,18 +44,36 @@ alias printoptions = PrintOptions
@value
struct PrintOptions:
var precision: Int
"""
The number of decimal places to include in the formatted string.
Defaults to 4.
"""
var suppress_small: Bool
var separator: String
"""
The separator between elements in the array. Defaults to a space.
"""
var padding: String
"""
The padding symbol between the elements at the edge and the brackets.
Defaults to an empty string.
"""
var threshold: Int
var line_width: Int
var edge_items: Int
"""
The number of items to display at the beginning and end of a dimension.
Defaults to 3.
"""
var sign: Bool
var float_format: String
var complex_format: String
var nan_string: String
var inf_string: String
var formatted_width: Int
"""
The width of the formatted string per element of array.
"""
var exponent_threshold: Int
var suppress_scientific: Bool

Expand Down