Skip to content

Commit f05de8c

Browse files
authored
[manipulation routines] Implement broadcast_to() for NDArray (#202)
Implements `broadcast_to()` for `NDArray`. Add tests. It can broadcast an ndarray of any shape to any compatible shape. The data will be copied into the new array. An example goes as follows. ```mojo from numojo.prelude import * from python import Python fn main() raises: var np = Python.import_module("numpy") var a = nm.random.rand(Shape(2, 3)) print(a) print(nm.routines.manipulation.broadcast_to(a, Shape(2, 2, 3))) print(np.broadcast_to(a.to_numpy(), (2, 2, 3))) ``` ```console [[0.8073 0.5361 0.4442] [0.9378 0.1910 0.2421]] 2D-array Shape(2,3) Strides(3,1) DType: f64 C-cont: True F-cont: False own data: True [[[0.8073 0.5361 0.4442] [0.9378 0.1910 0.2421]] [[0.8073 0.5361 0.4442] [0.9378 0.1910 0.2421]]] 3D-array Shape(2,2,3) Strides(6,3,1) DType: f64 C-cont: True F-cont: False own data: True [[[0.8074 0.5361 0.4442] [0.9378 0.1911 0.2421]] [[0.8074 0.5361 0.4442] [0.9378 0.1911 0.2421]]] ```
1 parent 75fb2a5 commit f05de8c

File tree

8 files changed

+174
-80
lines changed

8 files changed

+174
-80
lines changed

.github/workflows/run_tests.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ jobs:
3131

3232
- name: Install magic
3333
run: |
34-
curl -ssL https://magic.modular.com/ff414efd-16ac-4bf3-8efc-50b059272ab6 | bash
34+
curl -ssL https://magic.modular.com/deb181c4-455c-4abe-a263-afcff49ccf67 | bash
3535
3636
- name: Add path
3737
run: |

.github/workflows/test_pre_commit.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ jobs:
2626

2727
- name: Install magic
2828
run: |
29-
curl -ssL https://magic.modular.com/ff414efd-16ac-4bf3-8efc-50b059272ab6 | bash
29+
curl -ssL https://magic.modular.com/deb181c4-455c-4abe-a263-afcff49ccf67 | bash
3030
3131
- name: Add path
3232
run: |

numojo/__init__.mojo

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,7 @@ from numojo.routines.manipulation import (
194194
reshape,
195195
ravel,
196196
transpose,
197+
broadcast_to,
197198
flip,
198199
)
199200

numojo/core/matrix.mojo

Lines changed: 1 addition & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ from python import PythonObject, Python
1717
from numojo.core.ndarray import NDArray
1818
from numojo.core.own_data import OwnData
1919
from numojo.core.utility import _get_offset, _update_flags
20+
from numojo.routines.manipulation import broadcast_to
2021

2122
# ===----------------------------------------------------------------------===#
2223
# Matrix struct
@@ -1633,73 +1634,3 @@ fn _logic_func_matrix_matrix_to_matrix[
16331634
var _B = B
16341635

16351636
return C^
1636-
1637-
1638-
fn broadcast_to[
1639-
dtype: DType
1640-
](A: Matrix[dtype], shape: Tuple[Int, Int]) raises -> Matrix[dtype]:
1641-
"""
1642-
Broadcase the vector to the given shape.
1643-
1644-
Example:
1645-
1646-
```console
1647-
> from numojo import Matrix
1648-
> a = Matrix.fromstring("1 2 3", shape=(1, 3))
1649-
> print(mat.broadcast_to(a, (3, 3)))
1650-
[[1.0 2.0 3.0]
1651-
[1.0 2.0 3.0]
1652-
[1.0 2.0 3.0]]
1653-
> a = Matrix.fromstring("1 2 3", shape=(3, 1))
1654-
> print(mat.broadcast_to(a, (3, 3)))
1655-
[[1.0 1.0 1.0]
1656-
[2.0 2.0 2.0]
1657-
[3.0 3.0 3.0]]
1658-
> a = Matrix.fromstring("1", shape=(1, 1))
1659-
> print(mat.broadcast_to(a, (3, 3)))
1660-
[[1.0 1.0 1.0]
1661-
[1.0 1.0 1.0]
1662-
[1.0 1.0 1.0]]
1663-
> a = Matrix.fromstring("1 2", shape=(1, 2))
1664-
> print(mat.broadcast_to(a, (1, 2)))
1665-
[[1.0 2.0]]
1666-
> a = Matrix.fromstring("1 2 3 4", shape=(2, 2))
1667-
> print(mat.broadcast_to(a, (4, 2)))
1668-
Unhandled exception caught during execution: Cannot broadcast shape 2x2 to shape 4x2!
1669-
```
1670-
"""
1671-
1672-
var B = Matrix[dtype](shape)
1673-
if (A.shape[0] == shape[0]) and (A.shape[1] == shape[1]):
1674-
B = A
1675-
elif (A.shape[0] == 1) and (A.shape[1] == 1):
1676-
B = Matrix.full[dtype](shape, A[0, 0])
1677-
elif (A.shape[0] == 1) and (A.shape[1] == shape[1]):
1678-
for i in range(shape[0]):
1679-
memcpy(
1680-
dest=B._buf.ptr.offset(shape[1] * i),
1681-
src=A._buf.ptr,
1682-
count=shape[1],
1683-
)
1684-
elif (A.shape[1] == 1) and (A.shape[0] == shape[0]):
1685-
for i in range(shape[0]):
1686-
for j in range(shape[1]):
1687-
B._store(i, j, A._buf.ptr[i])
1688-
else:
1689-
var message = String(
1690-
"Cannot broadcast shape {}x{} to shape {}x{}!"
1691-
).format(A.shape[0], A.shape[1], shape[0], shape[1])
1692-
raise Error(message)
1693-
return B^
1694-
1695-
1696-
fn broadcast_to[
1697-
dtype: DType
1698-
](A: Scalar[dtype], shape: Tuple[Int, Int]) raises -> Matrix[dtype]:
1699-
"""
1700-
Broadcase the scalar to the given shape.
1701-
"""
1702-
1703-
var B = Matrix[dtype](shape)
1704-
B = Matrix.full[dtype](shape, A)
1705-
return B^

numojo/core/ndarray.mojo

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ from numojo.routines.manipulation import reshape, ravel
7575
# RawData type is just a wrapper of `UnsafePointer`.
7676
# RefData type has an extra property `indices`: getitem(i) -> A[I[i]].
7777
# TODO: Rename some variables or methods that should not be exposed to users.
78+
# TODO: Remove 0-d array. Raise errors if operations result in 0-d array.
7879
# ===----------------------------------------------------------------------===#
7980

8081

@@ -2323,10 +2324,7 @@ struct NDArray[dtype: DType = DType.float64](
23232324
min_value,
23242325
abs(val),
23252326
)
2326-
number_of_digits = max(
2327-
int(log10(float(max_value))) + 1,
2328-
abs(int(log10(float(min_value)))) + 1,
2329-
)
2327+
number_of_digits = int(log10(float(max_value))) + 1
23302328
number_of_digits_small_values = abs(int(log10(float(min_value)))) + 1
23312329

23322330
if dtype.is_floating_point():

numojo/prelude.mojo

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,11 @@ from numojo.prelude import *
2020
```
2121
"""
2222

23-
from numojo.core.ndarray import NDArray
24-
from numojo.core.matrix import Matrix
23+
import numojo as nm
24+
2525
from numojo.core.item import Item, item
26+
from numojo.core.matrix import Matrix
27+
from numojo.core.ndarray import NDArray
2628
from numojo.core.ndshape import Shape, NDArrayShape
2729

2830
from numojo.core.complex.complex_dtype import CDType

numojo/routines/manipulation.mojo

Lines changed: 148 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,12 @@
1+
# ===----------------------------------------------------------------------=== #
2+
# Distributed under the Apache 2.0 License with LLVM Exceptions.
3+
# See LICENSE and the LLVM License for more information.
4+
# https://github.com/Mojo-Numerics-and-Algorithms-group/NuMojo/blob/main/LICENSE
5+
# https://llvm.org/LICENSE.txt
6+
# ===----------------------------------------------------------------------=== #
7+
18
"""
29
Array manipulation routines.
3-
410
"""
511

612
from memory import UnsafePointer, memcpy
@@ -12,7 +18,12 @@ from numojo.core.ndshape import NDArrayShape, Shape
1218
from numojo.core.ndstrides import NDArrayStrides
1319
import numojo.core.matrix as matrix
1420
from numojo.core.matrix import Matrix
15-
from numojo.core.utility import _list_of_flipped_range
21+
from numojo.core.utility import _list_of_flipped_range, _get_offset
22+
23+
# ===----------------------------------------------------------------------=== #
24+
# TODO:
25+
# - When `OwnData` is supported, re-write `broadcast_to()`.
26+
# ===----------------------------------------------------------------------=== #
1627

1728
# ===----------------------------------------------------------------------=== #
1829
# Basic operations
@@ -272,6 +283,141 @@ fn transpose[dtype: DType](A: Matrix[dtype]) -> Matrix[dtype]:
272283
return B^
273284

274285

286+
# ===----------------------------------------------------------------------=== #
287+
# Changing number of dimensions
288+
# ===----------------------------------------------------------------------=== #
289+
290+
291+
fn broadcast_to[
292+
dtype: DType
293+
](a: NDArray[dtype], shape: NDArrayShape) raises -> NDArray[dtype]:
294+
if a.shape.ndim > shape.ndim:
295+
raise Error(
296+
String("Cannot broadcast shape {} to shape {}!").format(
297+
a.shape, shape
298+
)
299+
)
300+
301+
# Check whether broadcasting is possible or not.
302+
# We compare the two shapes dimension by dimension, starting from the trailing dimensions.
303+
304+
var b_strides = NDArrayStrides(
305+
shape
306+
) # Strides of b when referring to the data of a
307+
308+
for i in range(a.shape.ndim):
309+
if a.shape[a.shape.ndim - 1 - i] == shape[shape.ndim - 1 - i]:
310+
b_strides[shape.ndim - 1 - i] = a.strides[a.shape.ndim - 1 - i]
311+
elif a.shape[a.shape.ndim - 1 - i] == 1:
312+
b_strides[shape.ndim - 1 - i] = 0
313+
else:
314+
raise Error(
315+
String("Cannot broadcast shape {} to shape {}!").format(
316+
a.shape, shape
317+
)
318+
)
319+
for i in range(shape.ndim - a.shape.ndim):
320+
b_strides[i] = 0
321+
322+
# Start broadcasting.
323+
# TODO: When `OwnData` is supported, re-write this part.
324+
# We just need to change the shape and strides and re-use the data.
325+
326+
var b = NDArray[dtype](shape) # Construct array of targeted shape.
327+
# TODO: `b.strides = b_strides` when OwnData
328+
329+
# Iterate all items in the new array and fill in correct values.
330+
for offset in range(b.size):
331+
var remainder = offset
332+
var indices = Item(ndim=b.ndim, initialized=False)
333+
334+
for i in range(b.ndim):
335+
indices[i], remainder = divmod(
336+
remainder,
337+
b.strides[
338+
i
339+
], # TODO: Change b.strides to NDArrayStrides(b.shape) when OwnData
340+
)
341+
342+
(b._buf.ptr + offset).init_pointee_copy(
343+
a._buf.ptr[
344+
_get_offset(indices, b_strides)
345+
] # TODO: Change b_strides to b.strides when OwnData
346+
)
347+
348+
return b^
349+
350+
351+
fn broadcast_to[
352+
dtype: DType
353+
](A: Matrix[dtype], shape: Tuple[Int, Int]) raises -> Matrix[dtype]:
354+
"""
355+
Broadcasts the vector to the given shape.
356+
357+
Example:
358+
359+
```console
360+
> from numojo import Matrix
361+
> a = Matrix.fromstring("1 2 3", shape=(1, 3))
362+
> print(mat.broadcast_to(a, (3, 3)))
363+
[[1.0 2.0 3.0]
364+
[1.0 2.0 3.0]
365+
[1.0 2.0 3.0]]
366+
> a = Matrix.fromstring("1 2 3", shape=(3, 1))
367+
> print(mat.broadcast_to(a, (3, 3)))
368+
[[1.0 1.0 1.0]
369+
[2.0 2.0 2.0]
370+
[3.0 3.0 3.0]]
371+
> a = Matrix.fromstring("1", shape=(1, 1))
372+
> print(mat.broadcast_to(a, (3, 3)))
373+
[[1.0 1.0 1.0]
374+
[1.0 1.0 1.0]
375+
[1.0 1.0 1.0]]
376+
> a = Matrix.fromstring("1 2", shape=(1, 2))
377+
> print(mat.broadcast_to(a, (1, 2)))
378+
[[1.0 2.0]]
379+
> a = Matrix.fromstring("1 2 3 4", shape=(2, 2))
380+
> print(mat.broadcast_to(a, (4, 2)))
381+
Unhandled exception caught during execution: Cannot broadcast shape 2x2 to shape 4x2!
382+
```
383+
"""
384+
385+
var B = Matrix[dtype](shape)
386+
if (A.shape[0] == shape[0]) and (A.shape[1] == shape[1]):
387+
B = A
388+
elif (A.shape[0] == 1) and (A.shape[1] == 1):
389+
B = Matrix.full[dtype](shape, A[0, 0])
390+
elif (A.shape[0] == 1) and (A.shape[1] == shape[1]):
391+
for i in range(shape[0]):
392+
memcpy(
393+
dest=B._buf.ptr.offset(shape[1] * i),
394+
src=A._buf.ptr,
395+
count=shape[1],
396+
)
397+
elif (A.shape[1] == 1) and (A.shape[0] == shape[0]):
398+
for i in range(shape[0]):
399+
for j in range(shape[1]):
400+
B._store(i, j, A._buf.ptr[i])
401+
else:
402+
var message = String(
403+
"Cannot broadcast shape {}x{} to shape {}x{}!"
404+
).format(A.shape[0], A.shape[1], shape[0], shape[1])
405+
raise Error(message)
406+
return B^
407+
408+
409+
fn broadcast_to[
410+
dtype: DType
411+
](A: Scalar[dtype], shape: Tuple[Int, Int]) raises -> Matrix[dtype]:
412+
"""
413+
Broadcasts the scalar to the given shape.
414+
"""
415+
416+
var B = Matrix[dtype](shape)
417+
B = Matrix.full[dtype](shape, A)
418+
return B^
419+
420+
275421
# ===----------------------------------------------------------------------=== #
276422
# Rearranging elements
277423
# ===----------------------------------------------------------------------=== #

tests/routines/test_manipulation.mojo

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,3 +75,19 @@ def test_transpose():
7575
np.transpose(Anp, [1, 3, 0, 2]),
7676
"4-d `transpose` with arbitrary `axes` is broken.",
7777
)
78+
79+
80+
def test_broadcast():
81+
var np = Python.import_module("numpy")
82+
var a = nm.random.rand(Shape(2, 1, 3))
83+
var Anp = a.to_numpy()
84+
check(
85+
nm.broadcast_to(a, Shape(2, 2, 3)),
86+
np.broadcast_to(a.to_numpy(), (2, 2, 3)),
87+
"`broadcast_to` fails.",
88+
)
89+
check(
90+
nm.broadcast_to(a, Shape(2, 2, 2, 3)),
91+
np.broadcast_to(a.to_numpy(), (2, 2, 2, 3)),
92+
"`broadcast_to` fails.",
93+
)

0 commit comments

Comments
 (0)