1+ # ===----------------------------------------------------------------------=== #
2+ # Distributed under the Apache 2.0 License with LLVM Exceptions.
3+ # See LICENSE and the LLVM License for more information.
4+ # https://github.com/Mojo-Numerics-and-Algorithms-group/NuMojo/blob/main/LICENSE
5+ # https://llvm.org/LICENSE.txt
6+ # ===----------------------------------------------------------------------=== #
7+
18"""
29Array manipulation routines.
3-
410"""
511
612from memory import UnsafePointer, memcpy
@@ -12,7 +18,12 @@ from numojo.core.ndshape import NDArrayShape, Shape
1218from numojo.core.ndstrides import NDArrayStrides
1319import numojo.core.matrix as matrix
1420from numojo.core.matrix import Matrix
15- from numojo.core.utility import _list_of_flipped_range
21+ from numojo.core.utility import _list_of_flipped_range, _get_offset
22+
23+ # ===----------------------------------------------------------------------=== #
24+ # TODO :
25+ # - When `OwnData` is supported, re-write `broadcast_to()`.`
26+ # ===----------------------------------------------------------------------=== #
1627
1728# ===----------------------------------------------------------------------=== #
1829# Basic operations
@@ -272,6 +283,141 @@ fn transpose[dtype: DType](A: Matrix[dtype]) -> Matrix[dtype]:
272283 return B^
273284
274285
286+ # ===----------------------------------------------------------------------=== #
287+ # Changing number of dimensions
288+ # ===----------------------------------------------------------------------=== #
289+
290+
291+ fn broadcast_to [
292+ dtype : DType
293+ ](a : NDArray[dtype], shape : NDArrayShape) raises -> NDArray[dtype]:
294+ if a.shape.ndim > shape.ndim:
295+ raise Error(
296+ String(" Cannot broadcast shape {} to shape {} !" ).format(
297+ a.shape, shape
298+ )
299+ )
300+
301+ # Check whether broadcasting is possible or not.
302+ # We compare the shape from the trailing dimensions.
303+
304+ var b_strides = NDArrayStrides(
305+ shape
306+ ) # Strides of b when refer to data of a
307+
308+ for i in range (a.shape.ndim):
309+ if a.shape[a.shape.ndim - 1 - i] == shape[shape.ndim - 1 - i]:
310+ b_strides[shape.ndim - 1 - i] = a.strides[a.shape.ndim - 1 - i]
311+ elif a.shape[a.shape.ndim - 1 - i] == 1 :
312+ b_strides[shape.ndim - 1 - i] = 0
313+ else :
314+ raise Error(
315+ String(" Cannot broadcast shape {} to shape {} !" ).format(
316+ a.shape, shape
317+ )
318+ )
319+ for i in range (shape.ndim - a.shape.ndim):
320+ b_strides[i] = 0
321+
322+ # Start broadcasting.
323+ # TODO : When `OwnData` is supported, re-write this part.
324+ # We just need to change the shape and strides and re-use the data.
325+
326+ var b = NDArray[dtype](shape) # Construct array of targeted shape.
327+ # TODO : `b.strides = b_strides` when OwnData
328+
329+ # Iterate all items in the new array and fill in correct values.
330+ for offset in range (b.size):
331+ var remainder = offset
332+ var indices = Item(ndim = b.ndim, initialized = False )
333+
334+ for i in range (b.ndim):
335+ indices[i], remainder = divmod (
336+ remainder,
337+ b.strides[
338+ i
339+ ], # TODO : Change b.strides to NDArrayStrides(b.shape) when OwnData
340+ )
341+
342+ (b._buf.ptr + offset).init_pointee_copy(
343+ a._buf.ptr[
344+ _get_offset(indices, b_strides)
345+ ] # TODO : Change b_strides to b.strides when OwnData
346+ )
347+
348+ return b^
349+
350+
351+ fn broadcast_to [
352+ dtype : DType
353+ ](A : Matrix[dtype], shape : Tuple[Int, Int]) raises -> Matrix[dtype]:
354+ """
355+ Broadcasts the vector to the given shape.
356+
357+ Example:
358+
359+ ```console
360+ > from numojo import Matrix
361+ > a = Matrix.fromstring("1 2 3", shape=(1, 3))
362+ > print(mat.broadcast_to(a, (3, 3)))
363+ [[1.0 2.0 3.0]
364+ [1.0 2.0 3.0]
365+ [1.0 2.0 3.0]]
366+ > a = Matrix.fromstring("1 2 3", shape=(3, 1))
367+ > print(mat.broadcast_to(a, (3, 3)))
368+ [[1.0 1.0 1.0]
369+ [2.0 2.0 2.0]
370+ [3.0 3.0 3.0]]
371+ > a = Matrix.fromstring("1", shape=(1, 1))
372+ > print(mat.broadcast_to(a, (3, 3)))
373+ [[1.0 1.0 1.0]
374+ [1.0 1.0 1.0]
375+ [1.0 1.0 1.0]]
376+ > a = Matrix.fromstring("1 2", shape=(1, 2))
377+ > print(mat.broadcast_to(a, (1, 2)))
378+ [[1.0 2.0]]
379+ > a = Matrix.fromstring("1 2 3 4", shape=(2, 2))
380+ > print(mat.broadcast_to(a, (4, 2)))
381+ Unhandled exception caught during execution: Cannot broadcast shape 2x2 to shape 4x2!
382+ ```
383+ """
384+
385+ var B = Matrix[dtype](shape)
386+ if (A.shape[0 ] == shape[0 ]) and (A.shape[1 ] == shape[1 ]):
387+ B = A
388+ elif (A.shape[0 ] == 1 ) and (A.shape[1 ] == 1 ):
389+ B = Matrix.full[dtype](shape, A[0 , 0 ])
390+ elif (A.shape[0 ] == 1 ) and (A.shape[1 ] == shape[1 ]):
391+ for i in range (shape[0 ]):
392+ memcpy(
393+ dest = B._buf.ptr.offset(shape[1 ] * i),
394+ src = A._buf.ptr,
395+ count = shape[1 ],
396+ )
397+ elif (A.shape[1 ] == 1 ) and (A.shape[0 ] == shape[0 ]):
398+ for i in range (shape[0 ]):
399+ for j in range (shape[1 ]):
400+ B._store(i, j, A._buf.ptr[i])
401+ else :
402+ var message = String(
403+ " Cannot broadcast shape {} x{} to shape {} x{} !"
404+ ).format(A.shape[0 ], A.shape[1 ], shape[0 ], shape[1 ])
405+ raise Error(message)
406+ return B^
407+
408+
409+ fn broadcast_to [
410+ dtype : DType
411+ ](A : Scalar[dtype], shape : Tuple[Int, Int]) raises -> Matrix[dtype]:
412+ """
413+ Broadcasts the scalar to the given shape.
414+ """
415+
416+ var B = Matrix[dtype](shape)
417+ B = Matrix.full[dtype](shape, A)
418+ return B^
419+
420+
275421# ===----------------------------------------------------------------------=== #
276422# Rearranging elements
277423# ===----------------------------------------------------------------------=== #
0 commit comments