diff --git a/stdlib/public/TensorFlow/Gradients.swift b/stdlib/public/TensorFlow/Gradients.swift index fee63b347af06..48eefffc8f5f9 100644 --- a/stdlib/public/TensorFlow/Gradients.swift +++ b/stdlib/public/TensorFlow/Gradients.swift @@ -210,7 +210,10 @@ extension Tensor where Scalar : TensorFlowFloatingPoint { ) -> (Tensor, (Tensor) -> (Tensor, Tensor)) { return (lhs + rhs, { [lhsShape = lhs.shapeTensor, rhsShape = rhs.shapeTensor] v in - return (v.unbroadcast(toShape: lhsShape), v.unbroadcast(toShape: rhsShape)) + let (lhsAxes, rhsAxes) = + Raw.broadcastGradientArgs(s0: lhsShape, s1: rhsShape) + return (v.sum(squeezingAxes: lhsAxes).reshaped(toShape: lhsShape), + v.sum(squeezingAxes: rhsAxes).reshaped(toShape: rhsShape)) }) } @@ -220,8 +223,10 @@ extension Tensor where Scalar : TensorFlowFloatingPoint { ) -> (Tensor, (Tensor) -> (Tensor, Tensor)) { return (lhs - rhs, { [lhsShape = lhs.shapeTensor, rhsShape = rhs.shapeTensor] v in - return (v.unbroadcast(toShape: lhsShape), - -v.unbroadcast(toShape: rhsShape)) + let (lhsAxes, rhsAxes) = + Raw.broadcastGradientArgs(s0: lhsShape, s1: rhsShape) + return (v.sum(squeezingAxes: lhsAxes).reshaped(toShape: lhsShape), + -v.sum(squeezingAxes: rhsAxes).reshaped(toShape: rhsShape)) }) } @@ -229,10 +234,12 @@ extension Tensor where Scalar : TensorFlowFloatingPoint { static func _vjpMultiply( lhs: Tensor, rhs: Tensor ) -> (Tensor, (Tensor) -> (Tensor, Tensor)) { - return (lhs * rhs, { - [lhsShape = lhs.shapeTensor, rhsShape = rhs.shapeTensor] v in - ((rhs * v).unbroadcast(toShape: lhsShape), - (lhs * v).unbroadcast(toShape: rhsShape)) + return (lhs * rhs, { v in + let (lhsShape, rhsShape) = (lhs.shapeTensor, rhs.shapeTensor) + let (lhsAxes, rhsAxes) = + Raw.broadcastGradientArgs(s0: lhsShape, s1: rhsShape) + return ((rhs * v).sum(squeezingAxes: lhsAxes).reshaped(toShape: lhsShape), + (lhs * v).sum(squeezingAxes: rhsAxes).reshaped(toShape: rhsShape)) }) } @@ -240,10 +247,14 @@ extension Tensor where Scalar : 
TensorFlowFloatingPoint { static func _vjpDivide( lhs: Tensor, rhs: Tensor ) -> (Tensor, (Tensor) -> (Tensor, Tensor)) { - return (lhs / rhs, { - [lhsShape = lhs.shapeTensor, rhsShape = rhs.shapeTensor] v in - ((v / rhs).unbroadcast(toShape: lhsShape), - ((-lhs) / rhs.squared() * v).unbroadcast(toShape: rhsShape)) + return (lhs / rhs, { v in + let (lhsShape, rhsShape) = (lhs.shapeTensor, rhs.shapeTensor) + let (lhsAxes, rhsAxes) = + Raw.broadcastGradientArgs(s0: lhsShape, s1: rhsShape) + return ((v / rhs).sum(squeezingAxes: lhsAxes) + .reshaped(toShape: lhsShape), + (-lhs / rhs.squared() * v).sum(squeezingAxes: rhsAxes) + .reshaped(toShape: rhsShape)) }) } } @@ -267,14 +278,14 @@ extension Tensor where Scalar : TensorFlowFloatingPoint { static func _vjpSubtract( lhs: Tensor, rhs: Scalar ) -> (Tensor, (Tensor) -> (Tensor, Scalar)) { - return (lhs - rhs, { v in (v, 0 - v.sum().scalarized()) }) + return (lhs - rhs, { v in (v, -v.sum().scalarized()) }) } @inlinable static func _vjpSubtract( lhs: Scalar, rhs: Tensor ) -> (Tensor, (Tensor) -> (Scalar, Tensor)) { - return (lhs - rhs, { v in (v.sum().scalarized(), 0 - v) }) + return (lhs - rhs, { v in (v.sum().scalarized(), -v) }) } @inlinable @@ -296,7 +307,7 @@ extension Tensor where Scalar : TensorFlowFloatingPoint { lhs: Tensor, rhs: Scalar ) -> (Tensor, (Tensor) -> (Tensor, Scalar)) { return (lhs / rhs, { v in - (v / rhs, (v * (0 - lhs) / Tensor(rhs).squared()).sum().scalarized()) + (v / rhs, (v * -lhs / Tensor(rhs).squared()).sum().scalarized()) }) } @@ -317,7 +328,10 @@ func _vjpMinMaxHelper( let denom = 1 + Tensor(x .== y) let dfdx = vector * Tensor(x .== originalValue) / denom let dfdy = vector * Tensor(y .== originalValue) / denom - return (dfdx.unbroadcast(like: x), dfdy.unbroadcast(like: y)) + let (xShape, yShape) = (x.shapeTensor, y.shapeTensor) + let (xAxes, yAxes) = Raw.broadcastGradientArgs(s0: xShape, s1: yShape) + return (dfdx.sum(squeezingAxes: xAxes).reshaped(toShape: xShape), + dfdy.sum(squeezingAxes: 
yAxes).reshaped(toShape: yShape)) } @inlinable @@ -325,8 +339,9 @@ func _vjpMax( _ x: Tensor, _ y: Tensor ) -> (Tensor, (Tensor) -> (Tensor, Tensor)) { let value = max(x, y) - return (value, - { v in _vjpMinMaxHelper(x, y, originalValue: value, vector: v) }) + return (value, { v in + _vjpMinMaxHelper(x, y, originalValue: value, vector: v) + }) } @inlinable @@ -334,8 +349,9 @@ func _vjpMin( _ x: Tensor, _ y: Tensor ) -> (Tensor, (Tensor) -> (Tensor, Tensor)) { let value = min(x, y) - return (value, - { v in _vjpMinMaxHelper(x, y, originalValue: value, vector: v) }) + return (value, { v in + _vjpMinMaxHelper(x, y, originalValue: value, vector: v) + }) } @inlinable @@ -344,8 +360,12 @@ func _vjpPow( ) -> (Tensor, (Tensor) -> (Tensor, Tensor)) { let value = pow(x, y) return (value, { v in - ((v * y * pow(x, y-1)).unbroadcast(like: x), - (v * log(x) * value).unbroadcast(like: y)) + let (xShape, yShape) = (x.shapeTensor, y.shapeTensor) + let (xAxes, yAxes) = Raw.broadcastGradientArgs(s0: xShape, s1: yShape) + return ((v * y * pow(x, y-1)).sum(squeezingAxes: xAxes) + .reshaped(toShape: xShape), + (v * log(x) * value).sum(squeezingAxes: yAxes) + .reshaped(toShape: yShape)) }) } diff --git a/stdlib/public/TensorFlow/Ops.swift b/stdlib/public/TensorFlow/Ops.swift index 9f542deb7018b..1f0d6508b0db5 100644 --- a/stdlib/public/TensorFlow/Ops.swift +++ b/stdlib/public/TensorFlow/Ops.swift @@ -1665,30 +1665,6 @@ public extension Tensor { func broadcast(like other: Tensor) -> Tensor { return broadcast(toShape: other.shapeTensor) } -} - -public extension Tensor where Scalar : Numeric { - @inlinable - func unbroadcast(toShape otherShape: Tensor) -> Tensor { - let rankDiff = (rankTensor - otherShape.scalarCountTensor).rankLifted() - let ones: Tensor = Raw.fill(dims: rankDiff, value: Tensor(1)) - let paddedShape = ones ++ otherShape - let nonEqualIndices = paddedShape .!= shapeTensor - let broadcastIndices = Raw.where_(nonEqualIndices).flattened() - let unbroadcasted: Tensor = 
Raw.sum( - self, reductionIndices: Tensor(broadcastIndices), keepDims: false) - return Raw.reshape(unbroadcasted, shape: otherShape) - } - - @inlinable @inline(__always) - func unbroadcast(like other: Tensor) -> Tensor { - return unbroadcast(toShape: other.shapeTensor) - } - - @inlinable @inline(__always) - func unbroadcast(to shape: TensorShape) -> Tensor { - return unbroadcast(toShape: Tensor(shape.dimensions.map(Int32.init))) - } @inlinable @inline(__always) static func .= (lhs: inout Tensor, rhs: Tensor) { diff --git a/test/TensorFlowRuntime/tensor_autodiff_runtime.swift b/test/TensorFlowRuntime/tensor_autodiff_runtime.swift index 69cbe26b88284..adc46ecda88b0 100644 --- a/test/TensorFlowRuntime/tensor_autodiff_runtime.swift +++ b/test/TensorFlowRuntime/tensor_autodiff_runtime.swift @@ -25,6 +25,18 @@ TensorADTests.testAllBackends("TestSimpleGrad") { expectEqual([[20], [40]], gradient(at: [[10], [20]], in: square)) } +// TODO: This test is currently failing; investigate broadcasting gradient shapes. +TensorADTests.testAllBackends("TestBroadcastingGrad") { + func foo(_ x: Tensor, _ y: Tensor) -> Tensor { + return x * y + x + } + let x = Tensor(ones: [1, 2, 1, 4]) + let y = Tensor(ones: [4, 1, 3, 1]) + let (dx, dy) = gradient(at: x, y, in: foo) + expectEqual(x.shape, dx.shape) + expectEqual(y.shape, dy.shape) +} + TensorADTests.testAllBackends("TestGenericGrad") { func square(_ x: Tensor) -> Tensor { return x * x } @@ -219,6 +231,7 @@ TensorADTests.testAllBackends("Differentiate global") { } TensorADTests.testAllBackends("Side effects") { +///* This is failing reshape for some reason let foo: @differentiable (Tensor) -> Tensor = { x in var a = x a = a + x @@ -226,6 +239,7 @@ return a + x } expectEqual(Tensor([8, 8]), pullback(at: Tensor(4), in: foo)([1, 1])) +//*/ func bar(x: Tensor) -> Tensor { var a = x