diff --git a/Sources/TensorFlow/Core/DifferentialOperators.swift b/Sources/TensorFlow/Core/DifferentialOperators.swift
index 1092c2cea..74c18818b 100644
--- a/Sources/TensorFlow/Core/DifferentialOperators.swift
+++ b/Sources/TensorFlow/Core/DifferentialOperators.swift
@@ -21,7 +21,7 @@ public extension Differentiable {
     func gradient<R: TensorFlowFloatingPoint>(
         in f: @differentiable (Self) -> Tensor<R>
     ) -> TangentVector {
-        return self.pullback(in: f)(Tensor<R>(1))
+        return self.valueWithGradient(in: f).1
     }
 
     @inlinable
@@ -29,6 +29,7 @@ public extension Differentiable {
         in f: @differentiable (Self) -> Tensor<R>
     ) -> (value: Tensor<R>, gradient: TangentVector) {
         let (y, pb) = self.valueWithPullback(in: f)
+        precondition(y.rank == 0)
         return (y, pb(Tensor<R>(1)))
     }
 
@@ -37,7 +38,7 @@ public extension Differentiable {
         at x: T,
         in f: @differentiable (Self, T) -> Tensor<R>
     ) -> (TangentVector, T.TangentVector) {
-        return self.pullback(at: x, in: f)(Tensor<R>(1))
+        return self.valueWithGradient(at: x, in: f).1
     }
 
     @inlinable
@@ -46,6 +47,7 @@ public extension Differentiable {
         in f: @differentiable (Self, T) -> Tensor<R>
     ) -> (value: Tensor<R>, gradient: (TangentVector, T.TangentVector)) {
         let (y, pb) = self.valueWithPullback(at: x, in: f)
+        precondition(y.rank == 0)
         return (y, pb(Tensor<R>(1)))
     }
 }
@@ -63,6 +65,7 @@ public func valueWithGradient<T, R>(
 ) -> (value: Tensor<R>, gradient: T.TangentVector)
     where T: Differentiable, R: TensorFlowFloatingPoint {
     let (y, pullback) = valueWithPullback(at: x, in: f)
+    precondition(y.rank == 0)
     return (y, pullback(Tensor<R>(1)))
 }
 
@@ -74,6 +77,7 @@ public func valueWithGradient<T, U, R>(
 ) -> (value: Tensor<R>, gradient: (T.TangentVector, U.TangentVector))
     where T: Differentiable, U: Differentiable, R: TensorFlowFloatingPoint {
     let (y, pullback) = valueWithPullback(at: x, y, in: f)
+    precondition(y.rank == 0)
     return (y, pullback(Tensor<R>(1)))
 }
 
@@ -86,6 +90,7 @@ public func valueWithGradient<T, U, R>(
 // ) -> (value: Tensor<R>, gradient: (T.TangentVector, U.TangentVector, V.TangentVector))
 //     where T: Differentiable, U: Differentiable, V: Differentiable, R: TensorFlowFloatingPoint {
 //     let (y, pullback) = valueWithPullback(at: x, y, z, in: f)
+//     precondition(y.rank == 0)
 //     return (y, pullback(Tensor<R>(1)))
 // }
 
@@ -124,7 +129,7 @@ public func gradient<T, R>(
     at x: T,
     in f: @differentiable (T) -> Tensor<R>
 ) -> T.TangentVector where T: Differentiable, R: TensorFlowFloatingPoint {
-    return pullback(at: x, in: f)(Tensor<R>(1))
+    return valueWithGradient(at: x, in: f).1
 }
 
 @inlinable
@@ -134,7 +139,7 @@ public func gradient<T, U, R>(
     in f: @differentiable (T, U) -> Tensor<R>
 ) -> (T.TangentVector, U.TangentVector)
     where T: Differentiable, U: Differentiable, R: TensorFlowFloatingPoint {
-    return pullback(at: x, y, in: f)(Tensor<R>(1))
+    return valueWithGradient(at: x, y, in: f).1
 }
 
 // @inlinable
@@ -145,7 +150,7 @@ public func gradient<T, U, R>(
 //     in f: @differentiable (T, U, V) -> Tensor<R>
 // ) -> (T.TangentVector, U.TangentVector, V.TangentVector)
 //     where T: Differentiable, U: Differentiable, V: Differentiable, R: TensorFlowFloatingPoint {
-//     return pullback(at: x, y, z, in: f)(Tensor<R>(1))
+//     return valueWithGradient(at: x, y, z, in: f).1
 // }
 
 // Gradient (curried)
diff --git a/Sources/TensorFlow/Core/Tensor.swift b/Sources/TensorFlow/Core/Tensor.swift
index 01f5b9063..a36bddfdb 100644
--- a/Sources/TensorFlow/Core/Tensor.swift
+++ b/Sources/TensorFlow/Core/Tensor.swift
@@ -511,20 +511,12 @@ internal extension Tensor where Scalar: TensorFlowFloatingPoint {
         lhs: Tensor,
         rhs: Tensor
     ) -> (Tensor, (Tensor) -> (Tensor, Tensor)) {
-        return (lhs + rhs, { [
-            lhsShape = lhs.shape,
-            rhsShape = rhs.shape,
-            lhsShapeTensor = lhs.shapeTensor,
-            rhsShapeTensor = rhs.shapeTensor] v in
-            var lhsGrad = v
-            var rhsGrad = v
-            if lhsGrad.shape != lhsShape {
-                lhsGrad = lhsGrad.unbroadcasted(toShape: lhsShapeTensor)
-            }
-            if rhsGrad.shape != rhsShape {
-                rhsGrad = rhsGrad.unbroadcasted(toShape: rhsShapeTensor)
-            }
-            return (lhsGrad, rhsGrad)
+        return (lhs + rhs, { [lhsShape = lhs.shapeTensor, rhsShape = rhs.shapeTensor] v in
+            let lhsGrad = v
+            let rhsGrad = lhsGrad
+            let (lhsAxes, rhsAxes) = Raw.broadcastGradientArgs(s0: lhsShape, s1: rhsShape)
+            return (lhsGrad.sum(squeezingAxes: lhsAxes).reshaped(toShape: lhsShape),
+                    rhsGrad.sum(squeezingAxes: rhsAxes).reshaped(toShape: rhsShape))
         })
     }
 
@@ -533,20 +525,12 @@ internal extension Tensor where Scalar: TensorFlowFloatingPoint {
         lhs: Tensor,
         rhs: Tensor
     ) -> (Tensor, (Tensor) -> (Tensor, Tensor)) {
-        return (lhs - rhs, { [
-            lhsShape = lhs.shape,
-            rhsShape = rhs.shape,
-            lhsShapeTensor = lhs.shapeTensor,
-            rhsShapeTensor = rhs.shapeTensor] v in
-            var lhsGrad = v
-            var rhsGrad = -v
-            if lhsGrad.shape != lhsShape {
-                lhsGrad = lhsGrad.unbroadcasted(toShape: lhsShapeTensor)
-            }
-            if rhsGrad.shape != rhsShape {
-                rhsGrad = rhsGrad.unbroadcasted(toShape: rhsShapeTensor)
-            }
-            return (lhsGrad, rhsGrad)
+        return (lhs - rhs, { [lhsShape = lhs.shapeTensor, rhsShape = rhs.shapeTensor] v in
+            let lhsGrad = v
+            let rhsGrad = -lhsGrad
+            let (lhsAxes, rhsAxes) = Raw.broadcastGradientArgs(s0: lhsShape, s1: rhsShape)
+            return (lhsGrad.sum(squeezingAxes: lhsAxes).reshaped(toShape: lhsShape),
+                    rhsGrad.sum(squeezingAxes: rhsAxes).reshaped(toShape: rhsShape))
         })
     }
 }
diff --git a/Sources/TensorFlow/Operators/Math.swift b/Sources/TensorFlow/Operators/Math.swift
index cc3cb23fa..b0dc519db 100644
--- a/Sources/TensorFlow/Operators/Math.swift
+++ b/Sources/TensorFlow/Operators/Math.swift
@@ -43,20 +43,12 @@ extension Tensor: VectorNumeric where Scalar: Numeric {
 internal extension Tensor where Scalar: TensorFlowFloatingPoint {
     @inlinable
     static func _vjpMultiply(lhs: Tensor, rhs: Tensor) -> (Tensor, (Tensor) -> (Tensor, Tensor)) {
-        return (lhs * rhs, { [
-            lhsShape = lhs.shape,
-            rhsShape = rhs.shape,
-            lhsShapeTensor = lhs.shapeTensor,
-            rhsShapeTensor = rhs.shapeTensor] v in
-            var lhsGrad = rhs * v
-            var rhsGrad = lhs * v
-            if lhsGrad.shape != lhsShape {
-                lhsGrad = lhsGrad.unbroadcasted(toShape: lhsShapeTensor)
-            }
-            if rhsGrad.shape != rhsShape {
-                rhsGrad = rhsGrad.unbroadcasted(toShape: rhsShapeTensor)
-            }
-            return (lhsGrad, rhsGrad)
+        return (lhs * rhs, { [lhsShape = lhs.shapeTensor, rhsShape = rhs.shapeTensor] v in
+            let lhsGrad = rhs * v
+            let rhsGrad = lhs * v
+            let (lhsAxes, rhsAxes) = Raw.broadcastGradientArgs(s0: lhsShape, s1: rhsShape)
+            return (lhsGrad.sum(squeezingAxes: lhsAxes).reshaped(toShape: lhsShape),
+                    rhsGrad.sum(squeezingAxes: rhsAxes).reshaped(toShape: rhsShape))
         })
     }
 }
@@ -236,12 +228,12 @@ internal extension Tensor where Scalar: TensorFlowFloatingPoint {
 
     @inlinable
     static func _vjpSubtract(lhs: Tensor, rhs: Scalar) -> (Tensor, (Tensor) -> (Tensor, Scalar)) {
-        return (lhs - rhs, { v in (v, 0 - v.sum().scalarized()) })
+        return (lhs - rhs, { v in (v, -v.sum().scalarized()) })
     }
 
     @inlinable
     static func _vjpSubtract(lhs: Scalar, rhs: Tensor) -> (Tensor, (Tensor) -> (Scalar, Tensor)) {
-        return (lhs - rhs, { v in (v.sum().scalarized(), 0 - v) })
+        return (lhs - rhs, { v in (v.sum().scalarized(), -v) })
     }
 
     @inlinable
@@ -256,27 +248,19 @@ internal extension Tensor where Scalar: TensorFlowFloatingPoint {
 
     @inlinable
     static func _vjpDivide(lhs: Tensor, rhs: Tensor) -> (Tensor, (Tensor) -> (Tensor, Tensor)) {
-        return (lhs / rhs, { [
-            lhsShape = lhs.shape,
-            rhsShape = rhs.shape,
-            lhsShapeTensor = lhs.shapeTensor,
-            rhsShapeTensor = rhs.shapeTensor] v in
-            var lhsGrad = v / rhs
-            var rhsGrad = (-lhs) / rhs.squared() * v
-            if lhsGrad.shape != lhsShape {
-                lhsGrad = lhsGrad.unbroadcasted(toShape: lhsShapeTensor)
-            }
-            if rhsGrad.shape != rhsShape {
-                rhsGrad = rhsGrad.unbroadcasted(toShape: rhsShapeTensor)
-            }
-            return (lhsGrad, rhsGrad)
+        return (lhs / rhs, { [lhsShape = lhs.shapeTensor, rhsShape = rhs.shapeTensor] v in
+            let lhsGrad = v / rhs
+            let rhsGrad = -lhs / rhs.squared() * v
+            let (lhsAxes, rhsAxes) = Raw.broadcastGradientArgs(s0: lhsShape, s1: rhsShape)
+            return (lhsGrad.sum(squeezingAxes: lhsAxes).reshaped(toShape: lhsShape),
+                    rhsGrad.sum(squeezingAxes: rhsAxes).reshaped(toShape: rhsShape))
         })
     }
 
     @inlinable
     static func _vjpDivide(lhs: Tensor, rhs: Scalar) -> (Tensor, (Tensor) -> (Tensor, Scalar)) {
         return (lhs / rhs, { v in
-            (v / rhs, (v * (0 - lhs) / Tensor(rhs).squared()).sum().scalarized())
+            (v / rhs, (v * -lhs / Tensor(rhs).squared()).sum().scalarized())
         })
     }
 
@@ -704,15 +688,12 @@ internal func _vjpPow<T: TensorFlowFloatingPoint>(
     let value = pow(x, y)
     return (value, { v in
         let safeX = x.replacing(with: Tensor<T>(onesLike: x), where: x .<= 0)
-        var gradX = v * y * pow(x, y - 1)
-        var gradY = value * v * log(safeX)
-        if gradX.shape != x.shape {
-            gradX = gradX.unbroadcasted(like: x)
-        }
-        if gradY.shape != y.shape {
-            gradY = gradY.unbroadcasted(like: y)
-        }
-        return (gradX, gradY)
+        let lhsGrad = v * y * pow(x, y - 1)
+        let rhsGrad = value * v * log(safeX)
+        let (lhsShape, rhsShape) = (x.shapeTensor, y.shapeTensor)
+        let (lhsAxes, rhsAxes) = Raw.broadcastGradientArgs(s0: lhsShape, s1: rhsShape)
+        return (lhsGrad.sum(squeezingAxes: lhsAxes).reshaped(toShape: lhsShape),
+                rhsGrad.sum(squeezingAxes: rhsAxes).reshaped(toShape: rhsShape))
     })
 }
 
@@ -798,15 +779,12 @@ internal func _vjpMinMaxHelper<T: TensorFlowFloatingPoint>(
     seed: Tensor<T>
 ) -> (Tensor<T>, Tensor<T>) {
     let denominator = 1 + Tensor<T>(x .== y)
-    var gradX = seed * Tensor<T>(x .== originalValue) / denominator
-    var gradY = seed * Tensor<T>(y .== originalValue) / denominator
-    if gradX.shape != x.shape {
-        gradX = gradX.unbroadcasted(like: x)
-    }
-    if gradY.shape != y.shape {
-        gradY = gradY.unbroadcasted(like: y)
-    }
-    return (gradX, gradY)
+    let lhsGrad = seed * Tensor<T>(x .== originalValue) / denominator
+    let rhsGrad = seed * Tensor<T>(y .== originalValue) / denominator
+    let (lhsShape, rhsShape) = (x.shapeTensor, y.shapeTensor)
+    let (lhsAxes, rhsAxes) = Raw.broadcastGradientArgs(s0: lhsShape, s1: rhsShape)
+    return (lhsGrad.sum(squeezingAxes: lhsAxes).reshaped(toShape: lhsShape),
+            rhsGrad.sum(squeezingAxes: rhsAxes).reshaped(toShape: rhsShape))
 }
 
 //===------------------------------------------------------------------------------------------===//
diff --git a/Tests/TensorFlowTests/OperatorTests/MathTests.swift b/Tests/TensorFlowTests/OperatorTests/MathTests.swift
index 2a208b258..978834627 100644
--- a/Tests/TensorFlowTests/OperatorTests/MathTests.swift
+++ b/Tests/TensorFlowTests/OperatorTests/MathTests.swift
@@ -199,6 +199,17 @@ final class MathOperatorTests: XCTestCase {
         XCTAssertEqual(0.816997, Double(prediction.scalars[0]), accuracy: 0.0001)
     }
 
+    func testBroadcastedAddGradient() {
+        func foo(_ x: Tensor<Float>, _ y: Tensor<Float>) -> Tensor<Float> {
+            return (x + y).sum()
+        }
+        let x = Tensor<Float>(ones: [1, 2, 1, 4])
+        let y = Tensor<Float>(ones: [4, 1, 3, 1])
+        let (dx, dy) = gradient(at: x, y, in: foo)
+        XCTAssertEqual(x.shape, dx.shape)
+        XCTAssertEqual(y.shape, dy.shape)
+    }
+
     static var allTests = [
         ("testReduction", testReduction),
         ("testArgmax", testArgmax),
@@ -209,6 +220,7 @@ final class MathOperatorTests: XCTestCase {
         ("testMultiOpMath", testMultiOpMath),
         ("testXWPlusB", testXWPlusB),
         ("testXORInference", testXORInference),
-        ("testMLPClassifierStruct", testMLPClassifierStruct)
+        ("testMLPClassifierStruct", testMLPClassifierStruct),
+        ("testBroadcastedAddGradient", testBroadcastedAddGradient)
     ]
 }
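
Note on the mechanics (not part of the patch itself): each rewritten pullback calls `Raw.broadcastGradientArgs`, the binding for TensorFlow's `BroadcastGradientArgs` op, to get the axes along which each operand was broadcast, then sums the incoming gradient over those axes and reshapes the result back to the operand's shape. The final `reshaped(toShape:)` matters because `sum(squeezingAxes:)` drops the size-1 dimensions it reduces over. The helper below is a plain-Swift, host-side sketch of that axis computation, for intuition only; the name `broadcastGradientAxes` is hypothetical and is not an API in this codebase.

```swift
// Host-side model of what BroadcastGradientArgs computes: given two
// broadcast-compatible shapes, return the axes along which each operand
// was stretched, i.e. the axes its upstream gradient must be summed over.
// Hypothetical helper for illustration; not part of swift-apis.
func broadcastGradientAxes(_ s0: [Int], _ s1: [Int]) -> (r0: [Int], r1: [Int]) {
    // Broadcasting aligns shapes from the right; left-pad the shorter one with 1s.
    let rank = max(s0.count, s1.count)
    let p0 = Array(repeating: 1, count: rank - s0.count) + s0
    let p1 = Array(repeating: 1, count: rank - s1.count) + s1
    var r0: [Int] = []
    var r1: [Int] = []
    for axis in 0..<rank {
        // A size-1 dimension facing a larger one was broadcast, so the
        // gradient flowing back to that operand must be reduced along it.
        if p0[axis] == 1 && p1[axis] != 1 { r0.append(axis) }
        if p1[axis] == 1 && p0[axis] != 1 { r1.append(axis) }
    }
    return (r0, r1)
}

// The shapes from testBroadcastedAddGradient; x + y broadcasts to [4, 2, 3, 4].
let (xAxes, yAxes) = broadcastGradientAxes([1, 2, 1, 4], [4, 1, 3, 1])
print(xAxes)  // [0, 2]: summing dx over these axes gives [2, 4], reshaped to [1, 2, 1, 4]
print(yAxes)  // [1, 3]: summing dy over these axes gives [4, 3], reshaped to [4, 1, 3, 1]
```

The `precondition(y.rank == 0)` added to the `valueWithGradient` overloads makes the scalar-output assumption explicit: seeding the pullback with `Tensor(1)` is only the correct cotangent for a scalar result, so non-scalar outputs now trap instead of silently broadcasting the seed. Since the `gradient` overloads are rerouted through `valueWithGradient`, they inherit the same check.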