This repository was archived by the owner on Jul 1, 2023. It is now read-only.
Merged
15 changes: 10 additions & 5 deletions Sources/TensorFlow/Core/DifferentialOperators.swift
@@ -21,14 +21,15 @@ public extension Differentiable {
func gradient<R: TensorFlowFloatingPoint>(
in f: @differentiable (Self) -> Tensor<R>
) -> TangentVector {
return self.pullback(in: f)(Tensor<R>(1))
return self.valueWithGradient(in: f).1
}

@inlinable
func valueWithGradient<R: TensorFlowFloatingPoint>(
in f: @differentiable (Self) -> Tensor<R>
) -> (value: Tensor<R>, gradient: TangentVector) {
let (y, pb) = self.valueWithPullback(in: f)
precondition(y.rank == 0)
return (y, pb(Tensor<R>(1)))
}

@@ -37,7 +38,7 @@ public extension Differentiable {
at x: T,
in f: @differentiable (Self, T) -> Tensor<R>
) -> (TangentVector, T.TangentVector) {
return self.pullback(at: x, in: f)(Tensor<R>(1))
return self.valueWithGradient(at: x, in: f).1
}

@inlinable
@@ -46,6 +47,7 @@ public extension Differentiable {
in f: @differentiable (Self, T) -> Tensor<R>
) -> (value: Tensor<R>, gradient: (TangentVector, T.TangentVector)) {
let (y, pb) = self.valueWithPullback(at: x, in: f)
precondition(y.rank == 0)
return (y, pb(Tensor<R>(1)))
}
}
@@ -63,6 +65,7 @@ public func valueWithGradient<T, R>(
) -> (value: Tensor<R>, gradient: T.TangentVector)
where T: Differentiable, R: TensorFlowFloatingPoint {
let (y, pullback) = valueWithPullback(at: x, in: f)
precondition(y.rank == 0)
return (y, pullback(Tensor<R>(1)))
}

@@ -74,6 +77,7 @@ public func valueWithGradient<T, U, R>(
) -> (value: Tensor<R>, gradient: (T.TangentVector, U.TangentVector))
where T: Differentiable, U: Differentiable, R: TensorFlowFloatingPoint {
let (y, pullback) = valueWithPullback(at: x, y, in: f)
precondition(y.rank == 0)
return (y, pullback(Tensor<R>(1)))
}

@@ -86,6 +90,7 @@ public func valueWithGradient<T, U, R>(
// ) -> (value: Tensor<R>, gradient: (T.TangentVector, U.TangentVector, V.TangentVector))
// where T: Differentiable, U: Differentiable, V: Differentiable, R: TensorFlowFloatingPoint {
// let (y, pullback) = valueWithPullback(at: x, y, z, in: f)
// precondition(y.rank == 0)
// return (y, pullback(Tensor<R>(1)))
// }

@@ -124,7 +129,7 @@ public func gradient<T, R>(
at x: T,
in f: @differentiable (T) -> Tensor<R>
) -> T.TangentVector where T: Differentiable, R: TensorFlowFloatingPoint {
return pullback(at: x, in: f)(Tensor<R>(1))
return valueWithGradient(at: x, in: f).1
}

@inlinable
@@ -134,7 +139,7 @@ public func gradient<T, U, R>(
in f: @differentiable (T, U) -> Tensor<R>
) -> (T.TangentVector, U.TangentVector)
where T: Differentiable, U: Differentiable, R: TensorFlowFloatingPoint {
return pullback(at: x, y, in: f)(Tensor<R>(1))
return valueWithGradient(at: x, y, in: f).1
}

// @inlinable
@@ -145,7 +150,7 @@ public func gradient<T, U, R>(
// in f: @differentiable (T, U, V) -> Tensor<R>
// ) -> (T.TangentVector, U.TangentVector, V.TangentVector)
// where T: Differentiable, U: Differentiable, V: Differentiable, R: TensorFlowFloatingPoint {
// return pullback(at: x, y, z, in: f)(Tensor<R>(1))
// return valueWithGradient(at: x, y, z, in: f).1
// }

// Gradient (curried)
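Note: the effect of the added precondition(y.rank == 0) above is that gradient(at:in:) and valueWithGradient(at:in:) now trap when the differentiated function returns a non-scalar tensor. A minimal usage sketch, assuming the API as changed in this diff (the values in the comments are expected results, not captured output):

import TensorFlow

let x = Tensor<Float>([1, 2, 3])

// OK: sum() reduces to a rank-0 tensor, so the new precondition passes.
let dx = gradient(at: x) { t in (t * t).sum() }
// dx is expected to be [2.0, 4.0, 6.0].

// Not OK after this change: the result has rank 1, so `precondition(y.rank == 0)`
// traps at runtime. Use pullback(at:in:) with an explicit seed instead.
// let bad = gradient(at: x) { t in t * t }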
40 changes: 12 additions & 28 deletions Sources/TensorFlow/Core/Tensor.swift
@@ -511,20 +511,12 @@ internal extension Tensor where Scalar: TensorFlowFloatingPoint {
lhs: Tensor,
rhs: Tensor
) -> (Tensor, (Tensor) -> (Tensor, Tensor)) {
return (lhs + rhs, { [
lhsShape = lhs.shape,
rhsShape = rhs.shape,
lhsShapeTensor = lhs.shapeTensor,
rhsShapeTensor = rhs.shapeTensor] v in
var lhsGrad = v
var rhsGrad = v
if lhsGrad.shape != lhsShape {
lhsGrad = lhsGrad.unbroadcasted(toShape: lhsShapeTensor)
}
if rhsGrad.shape != rhsShape {
rhsGrad = rhsGrad.unbroadcasted(toShape: rhsShapeTensor)
}
return (lhsGrad, rhsGrad)
return (lhs + rhs, { [lhsShape = lhs.shapeTensor, rhsShape = rhs.shapeTensor] v in
let lhsGrad = v
let rhsGrad = lhsGrad
let (lhsAxes, rhsAxes) = Raw.broadcastGradientArgs(s0: lhsShape, s1: rhsShape)
return (lhsGrad.sum(squeezingAxes: lhsAxes).reshaped(toShape: lhsShape),
rhsGrad.sum(squeezingAxes: rhsAxes).reshaped(toShape: rhsShape))
})
}

@@ -533,20 +525,12 @@ internal extension Tensor where Scalar: TensorFlowFloatingPoint {
lhs: Tensor,
rhs: Tensor
) -> (Tensor, (Tensor) -> (Tensor, Tensor)) {
return (lhs - rhs, { [
lhsShape = lhs.shape,
rhsShape = rhs.shape,
lhsShapeTensor = lhs.shapeTensor,
rhsShapeTensor = rhs.shapeTensor] v in
var lhsGrad = v
var rhsGrad = -v
if lhsGrad.shape != lhsShape {
lhsGrad = lhsGrad.unbroadcasted(toShape: lhsShapeTensor)
}
if rhsGrad.shape != rhsShape {
rhsGrad = rhsGrad.unbroadcasted(toShape: rhsShapeTensor)
}
return (lhsGrad, rhsGrad)
return (lhs - rhs, { [lhsShape = lhs.shapeTensor, rhsShape = rhs.shapeTensor] v in
let lhsGrad = v
let rhsGrad = -lhsGrad
let (lhsAxes, rhsAxes) = Raw.broadcastGradientArgs(s0: lhsShape, s1: rhsShape)
return (lhsGrad.sum(squeezingAxes: lhsAxes).reshaped(toShape: lhsShape),
rhsGrad.sum(squeezingAxes: rhsAxes).reshaped(toShape: rhsShape))
})
}
}
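Note: a minimal sketch of what the Raw.broadcastGradientArgs-based reduction above is expected to compute for two broadcastable shapes. The axis values in the comments follow standard broadcasting rules and are assumptions, not output captured from a run:

let lhs = Tensor<Float>(ones: [1, 2, 1, 4])
let rhs = Tensor<Float>(ones: [4, 1, 3, 1])
// lhs + rhs broadcasts to shape [4, 2, 3, 4], so the incoming gradient v has that shape.
let v = Tensor<Float>(ones: [4, 2, 3, 4])
// broadcastGradientArgs returns, for each operand, the axes along which that operand
// was stretched during broadcasting; those axes are summed away in its gradient.
let (lhsAxes, rhsAxes) = Raw.broadcastGradientArgs(s0: lhs.shapeTensor, s1: rhs.shapeTensor)
// Expected: lhsAxes == [0, 2], rhsAxes == [1, 3].
let dLhs = v.sum(squeezingAxes: lhsAxes).reshaped(toShape: lhs.shapeTensor) // shape [1, 2, 1, 4]
let dRhs = v.sum(squeezingAxes: rhsAxes).reshaped(toShape: rhs.shapeTensor) // shape [4, 1, 3, 1]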
76 changes: 27 additions & 49 deletions Sources/TensorFlow/Operators/Math.swift
@@ -43,20 +43,12 @@ extension Tensor: VectorNumeric where Scalar: Numeric {
internal extension Tensor where Scalar: TensorFlowFloatingPoint {
@inlinable
static func _vjpMultiply(lhs: Tensor, rhs: Tensor) -> (Tensor, (Tensor) -> (Tensor, Tensor)) {
return (lhs * rhs, { [
lhsShape = lhs.shape,
rhsShape = rhs.shape,
lhsShapeTensor = lhs.shapeTensor,
rhsShapeTensor = rhs.shapeTensor] v in
var lhsGrad = rhs * v
var rhsGrad = lhs * v
if lhsGrad.shape != lhsShape {
lhsGrad = lhsGrad.unbroadcasted(toShape: lhsShapeTensor)
}
if rhsGrad.shape != rhsShape {
rhsGrad = rhsGrad.unbroadcasted(toShape: rhsShapeTensor)
}
return (lhsGrad, rhsGrad)
return (lhs * rhs, { [lhsShape = lhs.shapeTensor, rhsShape = rhs.shapeTensor] v in
let lhsGrad = rhs * v
let rhsGrad = lhs * v
let (lhsAxes, rhsAxes) = Raw.broadcastGradientArgs(s0: lhsShape, s1: rhsShape)
return (lhsGrad.sum(squeezingAxes: lhsAxes).reshaped(toShape: lhsShape),
rhsGrad.sum(squeezingAxes: rhsAxes).reshaped(toShape: rhsShape))
})
}
}
@@ -236,12 +228,12 @@ internal extension Tensor where Scalar: TensorFlowFloatingPoint {

@inlinable
static func _vjpSubtract(lhs: Tensor, rhs: Scalar) -> (Tensor, (Tensor) -> (Tensor, Scalar)) {
return (lhs - rhs, { v in (v, 0 - v.sum().scalarized()) })
return (lhs - rhs, { v in (v, -v.sum().scalarized()) })
}

@inlinable
static func _vjpSubtract(lhs: Scalar, rhs: Tensor) -> (Tensor, (Tensor) -> (Scalar, Tensor)) {
return (lhs - rhs, { v in (v.sum().scalarized(), 0 - v) })
return (lhs - rhs, { v in (v.sum().scalarized(), -v) })
}

@inlinable
@@ -256,27 +248,19 @@ internal extension Tensor where Scalar: TensorFlowFloatingPoint {

@inlinable
static func _vjpDivide(lhs: Tensor, rhs: Tensor) -> (Tensor, (Tensor) -> (Tensor, Tensor)) {
return (lhs / rhs, { [
lhsShape = lhs.shape,
rhsShape = rhs.shape,
lhsShapeTensor = lhs.shapeTensor,
rhsShapeTensor = rhs.shapeTensor] v in
var lhsGrad = v / rhs
var rhsGrad = (-lhs) / rhs.squared() * v
if lhsGrad.shape != lhsShape {
lhsGrad = lhsGrad.unbroadcasted(toShape: lhsShapeTensor)
}
if rhsGrad.shape != rhsShape {
rhsGrad = rhsGrad.unbroadcasted(toShape: rhsShapeTensor)
}
return (lhsGrad, rhsGrad)
return (lhs / rhs, { [lhsShape = lhs.shapeTensor, rhsShape = rhs.shapeTensor] v in
let lhsGrad = v / rhs
let rhsGrad = -lhs / rhs.squared() * v
let (lhsAxes, rhsAxes) = Raw.broadcastGradientArgs(s0: lhsShape, s1: rhsShape)
return (lhsGrad.sum(squeezingAxes: lhsAxes).reshaped(toShape: lhsShape),
rhsGrad.sum(squeezingAxes: rhsAxes).reshaped(toShape: rhsShape))
})
}

@inlinable
static func _vjpDivide(lhs: Tensor, rhs: Scalar) -> (Tensor, (Tensor) -> (Tensor, Scalar)) {
return (lhs / rhs, { v in
(v / rhs, (v * (0 - lhs) / Tensor(rhs).squared()).sum().scalarized())
(v / rhs, (v * -lhs / Tensor(rhs).squared()).sum().scalarized())
})
}

@@ -704,15 +688,12 @@ internal func _vjpPow<T: TensorFlowFloatingPoint>(
let value = pow(x, y)
return (value, { v in
let safeX = x.replacing(with: Tensor<T>(onesLike: x), where: x .<= 0)
var gradX = v * y * pow(x, y - 1)
var gradY = value * v * log(safeX)
if gradX.shape != x.shape {
gradX = gradX.unbroadcasted(like: x)
}
if gradY.shape != y.shape {
gradY = gradY.unbroadcasted(like: y)
}
return (gradX, gradY)
let lhsGrad = v * y * pow(x, y - 1)
let rhsGrad = value * v * log(safeX)
let (lhsShape, rhsShape) = (x.shapeTensor, y.shapeTensor)
let (lhsAxes, rhsAxes) = Raw.broadcastGradientArgs(s0: lhsShape, s1: rhsShape)
return (lhsGrad.sum(squeezingAxes: lhsAxes).reshaped(toShape: lhsShape),
rhsGrad.sum(squeezingAxes: rhsAxes).reshaped(toShape: rhsShape))
})
}

@@ -798,15 +779,12 @@ internal func _vjpMinMaxHelper<T: TensorFlowFloatingPoint>(
seed: Tensor<T>
) -> (Tensor<T>, Tensor<T>) {
let denominator = 1 + Tensor<T>(x .== y)
var gradX = seed * Tensor<T>(x .== originalValue) / denominator
var gradY = seed * Tensor<T>(y .== originalValue) / denominator
if gradX.shape != x.shape {
gradX = gradX.unbroadcasted(like: x)
}
if gradY.shape != y.shape {
gradY = gradY.unbroadcasted(like: y)
}
return (gradX, gradY)
let lhsGrad = seed * Tensor<T>(x .== originalValue) / denominator
let rhsGrad = seed * Tensor<T>(y .== originalValue) / denominator
let (lhsShape, rhsShape) = (x.shapeTensor, y.shapeTensor)
let (lhsAxes, rhsAxes) = Raw.broadcastGradientArgs(s0: lhsShape, s1: rhsShape)
return (lhsGrad.sum(squeezingAxes: lhsAxes).reshaped(toShape: lhsShape),
rhsGrad.sum(squeezingAxes: rhsAxes).reshaped(toShape: rhsShape))
}

//===------------------------------------------------------------------------------------------===//
14 changes: 13 additions & 1 deletion Tests/TensorFlowTests/OperatorTests/MathTests.swift
@@ -199,6 +199,17 @@ final class MathOperatorTests: XCTestCase {
XCTAssertEqual(0.816997, Double(prediction.scalars[0]), accuracy: 0.0001)
}

func testBroadcastedAddGradient() {
func foo(_ x: Tensor<Float>, _ y: Tensor<Float>) -> Tensor<Float> {
return (x + y).sum()
}
let x = Tensor<Float>(ones: [1, 2, 1, 4])
let y = Tensor<Float>(ones: [4, 1, 3, 1])
let (dx, dy) = gradient(at: x, y, in: foo)
Contributor:
In this test case, the use of gradient(at:in:) is not valid, because gradient is only mathematically defined for functions that return a scalar. We can either add a sum() in foo(_:_:) or use pullback(at:in:).

Contributor (Author):
That's what I thought. This was actually the failing case in the other PR, which is why I copied it here, but I wasn't sure of the semantics of gradient. I'll add a call to sum() and remove the seed broadcasts then. :)

XCTAssertEqual(x.shape, dx.shape)
XCTAssertEqual(y.shape, dy.shape)
}

static var allTests = [
("testReduction", testReduction),
("testArgmax", testArgmax),
Expand All @@ -209,6 +220,7 @@ final class MathOperatorTests: XCTestCase {
("testMultiOpMath", testMultiOpMath),
("testXWPlusB", testXWPlusB),
("testXORInference", testXORInference),
("testMLPClassifierStruct", testMLPClassifierStruct)
("testMLPClassifierStruct", testMLPClassifierStruct),
("testBroadcastedAddGradient", testBroadcastedAddGradient)
]
}
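As a follow-up to the review thread on testBroadcastedAddGradient: a sketch of the pullback(at:in:) alternative mentioned there, for a version of foo that does not reduce to a scalar. The helper name fooNoSum and the seed shape [4, 2, 3, 4] (the assumed broadcast result shape) are illustrative:

func fooNoSum(_ x: Tensor<Float>, _ y: Tensor<Float>) -> Tensor<Float> {
    return x + y
}
let x = Tensor<Float>(ones: [1, 2, 1, 4])
let y = Tensor<Float>(ones: [4, 1, 3, 1])
// pullback(at:in:) works for non-scalar outputs because the caller supplies the seed.
let seed = Tensor<Float>(ones: [4, 2, 3, 4])
let (dx, dy) = pullback(at: x, y, in: fooNoSum)(seed)
// dx and dy are expected to have shapes [1, 2, 1, 4] and [4, 1, 3, 1], respectively.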