Skip to content

Broadcast refactor #3220

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jun 19, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 38 additions & 70 deletions src/type/matrix/utils/broadcast.js
Original file line number Diff line number Diff line change
@@ -1,72 +1,40 @@
import { checkBroadcastingRules } from '../../../utils/array.js'
import { factory } from '../../../utils/factory.js'

const name = 'broadcast'

const dependancies = ['concat']

export const createBroadcast = /* #__PURE__ */ factory(
name, dependancies,
({ concat }) => {
/**
* Broadcasts two matrices, and return both in an array
* It checks if it's possible with broadcasting rules
*
* @param {Matrix} A First Matrix
* @param {Matrix} B Second Matrix
*
* @return {Matrix[]} [ broadcastedA, broadcastedB ]
*/
return function (A, B) {
const N = Math.max(A._size.length, B._size.length) // max number of dims
if (A._size.length === B._size.length) {
if (A._size.every((dim, i) => dim === B._size[i])) {
// If matrices have the same size return them
return [A, B]
}
}

const sizeA = _padLeft(A._size, N, 0) // pad to the left to align dimensions to the right
const sizeB = _padLeft(B._size, N, 0) // pad to the left to align dimensions to the right

// calculate the max dimensions
const sizeMax = []

for (let dim = 0; dim < N; dim++) {
sizeMax[dim] = Math.max(sizeA[dim], sizeB[dim])
}

// check if the broadcasting rules applyes for both matrices
checkBroadcastingRules(sizeA, sizeMax)
checkBroadcastingRules(sizeB, sizeMax)

// reshape A or B if needed to make them ready for concat
let AA = A.clone()
let BB = B.clone()
if (AA._size.length < N) {
AA.reshape(_padLeft(AA._size, N, 1))
} else if (BB._size.length < N) {
BB.reshape(_padLeft(BB._size, N, 1))
}

// stretches the matrices on each dimension to make them the same size
for (let dim = 0; dim < N; dim++) {
if (AA._size[dim] < sizeMax[dim]) { AA = _stretch(AA, sizeMax[dim], dim) }
if (BB._size[dim] < sizeMax[dim]) { BB = _stretch(BB, sizeMax[dim], dim) }
}

// return the array with the two broadcasted matrices
return [AA, BB]
}

function _padLeft (shape, N, filler) {
// pads an array of dimensions with numbers to the left, unitl the number of dimensions is N
return [...Array(N - shape.length).fill(filler), ...shape]
}
import { broadcastSizes, broadcastTo as broadcastArrayTo } from '../../../utils/array.js'
import { deepStrictEqual } from '../../../utils/object.js'

/**
* Broadcasts two matrices, and return both in an array
* It checks if it's possible with broadcasting rules
*
* @param {Matrix} A First Matrix
* @param {Matrix} B Second Matrix
*
* @return {Matrix[]} [ broadcastedA, broadcastedB ]
*/

export function broadcast (A, B) {
if (deepStrictEqual(A._size, B._size)) {
// If matrices have the same size return them
return [A, B]
}

function _stretch (arrayToStretch, sizeToStretch, dimToStretch) {
// stretches a matrix up to a certain size in a certain dimension
return concat(...Array(sizeToStretch).fill(arrayToStretch), dimToStretch)
}
// calculate the broadcasted sizes
const newSize = broadcastSizes(A._size, B._size)

// return the array with the two broadcasted matrices
return [A, B].map(M => _broadcastTo(M, newSize))
}

/**
* Broadcasts a matrix to the given size.
*
* @param {Matrix} M - The matrix to be broadcasted.
* @param {number[]} size - The desired size of the broadcasted matrix.
* @returns {Matrix} The broadcasted matrix.
* @throws {Error} If the size parameter is not an array of numbers.
*/
function _broadcastTo (M, size) {
if (deepStrictEqual(M._size, size)) {
return M
}
)
return M.create(broadcastArrayTo(M._data, size), M._datatype)
}
7 changes: 3 additions & 4 deletions src/type/matrix/utils/matrixAlgorithmSuite.js
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,15 @@ import { factory } from '../../../utils/factory.js'
import { extend } from '../../../utils/object.js'
import { createMatAlgo13xDD } from './matAlgo13xDD.js'
import { createMatAlgo14xDs } from './matAlgo14xDs.js'
import { createBroadcast } from './broadcast.js'
import { broadcast } from './broadcast.js'

const name = 'matrixAlgorithmSuite'
const dependencies = ['typed', 'matrix', 'concat']
const dependencies = ['typed', 'matrix']

export const createMatrixAlgorithmSuite = /* #__PURE__ */ factory(
name, dependencies, ({ typed, matrix, concat }) => {
name, dependencies, ({ typed, matrix }) => {
const matAlgo13xDD = createMatAlgo13xDD({ typed })
const matAlgo14xDs = createMatAlgo14xDs({ typed })
const broadcast = createBroadcast({ concat })

/**
* Return a signatures object with the usual boilerplate of
Expand Down
4 changes: 2 additions & 2 deletions test/unit-tests/function/arithmetic/dotDivide.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -171,8 +171,8 @@ describe('dotDivide', function () {
assert.deepStrictEqual(dotDivide(a, b), math.matrix([[1 / 5, Infinity], [0, 4 / 8]]))
})

it('should throw an error when dividing element-wise with differing size', function () {
assert.throws(function () { dotDivide(math.sparse([[1, 2], [3, 4]]), math.sparse([[1]])) })
it('should throw an error when dividing element-wise with differing size is not broadcastable', function () {
assert.throws(function () { dotDivide(math.sparse([[1, 2], [3, 4]]), math.sparse([1, 2, 3])) })
})
})

Expand Down
4 changes: 1 addition & 3 deletions test/unit-tests/type/matrix/utils/broadcast.test.js
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
import assert from 'assert'
import math from '../../../../../src/defaultInstance.js'
import { createBroadcast } from '../../../../../src/type/matrix/utils/broadcast.js'
const concat = math.concat
import { broadcast } from '../../../../../src/type/matrix/utils/broadcast.js'
const matrix = math.matrix
const broadcast = createBroadcast({ concat })

describe('broadcast', function () {
it('should return matrices as such if they are the same size', function () {
Expand Down