diff --git a/stdlib/public/TensorFlow/Gradients.swift b/stdlib/public/TensorFlow/Gradients.swift
index 794af47ba6021..c376b1dfd7de3 100644
--- a/stdlib/public/TensorFlow/Gradients.swift
+++ b/stdlib/public/TensorFlow/Gradients.swift
@@ -645,16 +645,14 @@ extension Tensor where Scalar : TensorFlowFloatingPoint {
   func _vjpBroadcast(
     toShape shape: Tensor<Int32>
   ) -> (Tensor, (Tensor) -> Tensor) {
-    return (broadcast(toShape: shape), { [origShape = self.shapeTensor] v in
+    return (broadcast(toShape: shape), { [origShape = shapeTensor] v in
       v.unbroadcast(toShape: origShape)
     })
   }
 
   @inlinable
-  func _vjpUnbroadcast(
-    toShape shape: Tensor<Int32>
-  ) -> (Tensor, (Tensor) -> Tensor) {
-    return (unbroadcast(toShape: shape), { [origShape = self.shapeTensor] v in
+  func _vjpUnbroadcast(to shape: TensorShape) -> (Tensor, (Tensor) -> Tensor) {
+    return (unbroadcast(to: shape), { [origShape = shapeTensor] v in
       v.broadcast(toShape: origShape)
     })
   }
diff --git a/stdlib/public/TensorFlow/Ops.swift b/stdlib/public/TensorFlow/Ops.swift
index 11a84cc6ad261..47ae64e302ed6 100644
--- a/stdlib/public/TensorFlow/Ops.swift
+++ b/stdlib/public/TensorFlow/Ops.swift
@@ -1601,7 +1601,7 @@ public extension Tensor {
 public extension Tensor {
   @inlinable
   @differentiable(wrt: self, vjp: _vjpBroadcast(toShape:)
-                  where Scalar : TensorFlowFloatingPoint)
+    where Scalar : TensorFlowFloatingPoint)
   func broadcast(toShape shape: Tensor<Int32>) -> Tensor {
     return Raw.broadcastTo(self, shape: shape)
   }
@@ -1609,14 +1609,14 @@ public extension Tensor {
   @inlinable
   @differentiable(wrt: self where Scalar : TensorFlowFloatingPoint)
   func broadcast(to shape: TensorShape) -> Tensor {
-    return broadcast(toShape: Tensor<Int32>({ shape.dimensions.map(Int32.init) }()))
+    return broadcast(
+      toShape: Tensor<Int32>({ shape.dimensions.map(Int32.init) }()))
   }
 
   /// Broadcast to the same shape as the specified `Tensor`.
   /// - Precondition: The specified shape must be compatible for broadcasting.
   @inlinable
-  @differentiable(wrt: self
-    where Scalar : TensorFlowFloatingPoint)
+  @differentiable(wrt: self where Scalar : TensorFlowFloatingPoint)
   func broadcast(like other: Tensor) -> Tensor {
     return broadcast(toShape: other.shapeTensor)
   }
@@ -1624,17 +1624,13 @@ public extension Tensor {
 
 public extension Tensor where Scalar : Numeric {
   @inlinable
-  @differentiable(wrt: self, vjp: _vjpUnbroadcast(toShape:)
-    where Scalar : TensorFlowFloatingPoint)
+  @differentiable(wrt: self where Scalar : TensorFlowFloatingPoint)
  func unbroadcast(toShape otherShape: Tensor<Int32>) -> Tensor {
-    let rankDiff = (rankTensor - otherShape.scalarCountTensor).rankLifted()
-    let ones: Tensor<Int32> = Raw.fill(dims: rankDiff, value: Tensor<Int32>(1))
-    let paddedShape = ones ++ otherShape
-    let nonEqualIndices = paddedShape .!= shapeTensor
-    let broadcastIndices = Raw.where_(nonEqualIndices).flattened()
-    let unbroadcasted: Tensor = Raw.sum(
-      self, reductionIndices: Tensor<Int32>(broadcastIndices), keepDims: false)
-    return Raw.reshape(unbroadcasted, shape: otherShape)
+    // TODO: Simplify this once differentiating control flow is supported.
+    return unbroadcast(to: {
+      precondition(otherShape.rank == 1)
+      return TensorShape(otherShape.scalars.map(Int.init))
+    }())
   }
 
   @inlinable
@@ -1644,9 +1640,31 @@ public extension Tensor where Scalar : Numeric {
   }
 
   @inlinable
-  @differentiable(wrt: self where Scalar : TensorFlowFloatingPoint)
+  @differentiable(wrt: self, vjp: _vjpUnbroadcast(to:)
+    where Scalar : TensorFlowFloatingPoint)
   func unbroadcast(to shape: TensorShape) -> Tensor {
-    return unbroadcast(toShape: Tensor<Int32>({ shape.dimensions.map(Int32.init) }()))
+    let dimensions = self.shape.dimensions
+    var otherDimensions = shape.dimensions
+    let rankDifference = dimensions.count - otherDimensions.count
+    precondition(rankDifference >= 0, """
+      The rank of 'self' must be greater than or equal to the number of \
+      dimensions in the destination shape
+      """)
+    if rankDifference > 0 {
+      otherDimensions.insert(
+        contentsOf: repeatElement(1, count: rankDifference),
+        at: 0
+      )
+    }
+    assert(dimensions.count == otherDimensions.count)
+    var axes: [Int] = []
+    axes.reserveCapacity(dimensions.count)
+    for (i, (dim, otherDim)) in zip(dimensions, otherDimensions).enumerated() {
+      if dim == otherDim { continue }
+      if otherDim == 1 { axes.append(i); continue }
+      preconditionFailure("Cannot unbroadcast \(self.shape) to \(shape)")
+    }
+    return sum(alongAxes: axes).reshaped(to: shape)
   }
 
   @inlinable
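
For reference (not part of the patch): the host-side axis computation introduced in the new `unbroadcast(to:)` body can be sketched with plain Swift arrays. `unbroadcastAxes` below is a hypothetical standalone helper that mirrors the same left-padding and dimension-comparison logic, with no TensorFlow dependency.

// Illustrative sketch only: mirrors the axis computation in the new
// `unbroadcast(to:)` body using plain [Int] shapes. `unbroadcastAxes` is a
// hypothetical helper, not an API in this patch.
func unbroadcastAxes(from dimensions: [Int], to targetDimensions: [Int]) -> [Int] {
  var otherDimensions = targetDimensions
  let rankDifference = dimensions.count - otherDimensions.count
  precondition(rankDifference >= 0,
               "'self' must have rank >= the destination shape's rank")
  // Left-pad the destination shape with 1s so both shapes have the same rank.
  if rankDifference > 0 {
    otherDimensions.insert(
      contentsOf: repeatElement(1, count: rankDifference), at: 0)
  }
  var axes: [Int] = []
  for (i, (dim, otherDim)) in zip(dimensions, otherDimensions).enumerated() {
    if dim == otherDim { continue }
    // A broadcast dimension in the destination shape becomes a reduction axis.
    if otherDim == 1 { axes.append(i); continue }
    preconditionFailure("Cannot unbroadcast \(dimensions) to \(targetDimensions)")
  }
  return axes
}

// Example: a [2, 3, 4] tensor unbroadcast to [3, 1] sums along axes 0 and 2
// before the final reshape.
print(unbroadcastAxes(from: [2, 3, 4], to: [3, 1]))  // [0, 2]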