diff --git a/stdlib/public/TensorFlow/Gradients.swift b/stdlib/public/TensorFlow/Gradients.swift index 4430e04354462..dc4f2cfe942b5 100644 --- a/stdlib/public/TensorFlow/Gradients.swift +++ b/stdlib/public/TensorFlow/Gradients.swift @@ -553,12 +553,10 @@ extension Tensor where Scalar : TensorFlowFloatingPoint { } @inlinable - func _vjpExpandingShape( - at shapeIndex: Int - ) -> (Tensor, (Tensor) -> Tensor) { - let value = expandingShape(at: shapeIndex) + func _vjpExpandingShape(at axes: [Int]) -> (Tensor, (Tensor) -> Tensor) { + let value = self.expandingShape(at: axes) return (value, { v in - v.squeezingShape(at: shapeIndex) + v.squeezingShape(at: axes) }) } } @@ -579,7 +577,11 @@ extension Tensor where Scalar : TensorFlowFloatingPoint { squeezingAxes axes: Tensor ) -> (Tensor, (Tensor) -> Tensor) { let value = sum(squeezingAxes: axes) - return (value, { [shape = shapeTensor] in $0.broadcast(toShape: shape) }) + return (value, { [shape = shapeTensor] in + var res = $0 + for i in axes.array.scalars { res = res.expandingShape(at: Int(i)) } + return res.broadcast(toShape: shape) + }) } @inlinable @@ -591,15 +593,6 @@ extension Tensor where Scalar : TensorFlowFloatingPoint { }) } - @inlinable - func _vjpMean(squeezingAxes axes: [Int]) -> (Tensor, (Tensor) -> Tensor) { - let value = mean(squeezingAxes: axes) - return (value, { [shape = shapeTensor, - count = axes.map { shape[$0] }.reduce(1, *)] in - $0.broadcast(toShape: shape) / Tensor(Scalar(count)) - }) - } - @inlinable func _vjpMean( squeezingAxes axes: Tensor @@ -607,7 +600,9 @@ extension Tensor where Scalar : TensorFlowFloatingPoint { let value = mean(squeezingAxes: axes) let count = Raw.gather(params: shapeTensor, indices: axes).product() return (value, { [shape = shapeTensor] in - $0.broadcast(toShape: shape) / Tensor(count) + var res = $0 + for i in axes.array.scalars { res = res.expandingShape(at: Int(i)) } + return res.broadcast(toShape: shape) / Tensor(count) }) } } diff --git a/stdlib/public/TensorFlow/Tensor.swift b/stdlib/public/TensorFlow/Tensor.swift index f0e3972e7771f..9ab73ac850ea9 100644 --- a/stdlib/public/TensorFlow/Tensor.swift +++ b/stdlib/public/TensorFlow/Tensor.swift @@ -706,14 +706,24 @@ public extension Tensor { } /// Returns a shape-expanded `Tensor`, with a dimension of 1 inserted at the - /// specified shape index. + /// specified shape indices. + @inlinable @inline(__always) + @differentiable(wrt: self where Scalar : TensorFlowFloatingPoint) + func expandingShape(at axes: Int...) -> Tensor { + return expandingShape(at: axes) + } + + /// Returns a shape-expanded `Tensor`, with a dimension of 1 inserted at the + /// specified shape indices. @inlinable @inline(__always) @differentiable( wrt: self, vjp: _vjpExpandingShape(at:) where Scalar : TensorFlowFloatingPoint ) - func expandingShape(at shapeIndex: Int) -> Tensor { - return Raw.expandDims(self, dim: Tensor(Int32(shapeIndex))) + func expandingShape(at axes: [Int]) -> Tensor { + var res = self + for i in axes { res = Raw.expandDims(res, dim: Tensor(Int32(i))) } + return res } /// Remove the specified dimensions of size 1 from the shape of a tensor. If diff --git a/test/TensorFlowRuntime/tensor.swift b/test/TensorFlowRuntime/tensor.swift index e67a2c3364d8f..08d63630555cc 100644 --- a/test/TensorFlowRuntime/tensor.swift +++ b/test/TensorFlowRuntime/tensor.swift @@ -674,6 +674,20 @@ TensorTests.testAllBackends("MLPClassifierStruct") { expectPointwiseNearlyEqual([0.816997], prediction.scalars) } +TensorTests.testAllBackends("ExpandingShape") { + // 2 x 3 -> 1 x 2 x 1 x 3 x 1 + let matrix = Tensor([[0, 1, 2], [3, 4, 5]]) + let reshaped = matrix.expandingShape(at: 0,2,4) + + expectEqual([1, 2, 1, 3, 1], reshaped.shape) + expectEqual(Array(0..<6), reshaped.scalars) + + // 1 x 2 x 1 x 3 x 1 -> 2 x 3 + let rereshaped = reshaped.squeezingShape(at: 0,2,4) + expectEqual([2, 3], rereshaped.shape) + expectEqual(Array(0..<6), rereshaped.scalars) +} + TensorTests.testAllBackends("Reshape") { // 2 x 3 -> 1 x 3 x 1 x 2 x 1 let matrix = Tensor([[0, 1, 2], [3, 4, 5]])