diff --git a/stdlib/public/TensorFlow/Gradients.swift b/stdlib/public/TensorFlow/Gradients.swift index 4430e04354462..ebc5d0e4d6ef3 100644 --- a/stdlib/public/TensorFlow/Gradients.swift +++ b/stdlib/public/TensorFlow/Gradients.swift @@ -579,7 +579,11 @@ extension Tensor where Scalar : TensorFlowFloatingPoint { squeezingAxes axes: Tensor ) -> (Tensor, (Tensor) -> Tensor) { let value = sum(squeezingAxes: axes) - return (value, { [shape = shapeTensor] in $0.broadcast(toShape: shape) }) + return (value, { [shape = shapeTensor] in + var res = $0 + for i in axes.array.scalars { res = res.expandingShape(at: Int(i)) } + return res.broadcast(toShape: shape) + }) } @inlinable @@ -591,15 +595,6 @@ extension Tensor where Scalar : TensorFlowFloatingPoint { }) } - @inlinable - func _vjpMean(squeezingAxes axes: [Int]) -> (Tensor, (Tensor) -> Tensor) { - let value = mean(squeezingAxes: axes) - return (value, { [shape = shapeTensor, - count = axes.map { shape[$0] }.reduce(1, *)] in - $0.broadcast(toShape: shape) / Tensor(Scalar(count)) - }) - } - @inlinable func _vjpMean( squeezingAxes axes: Tensor @@ -607,7 +602,9 @@ extension Tensor where Scalar : TensorFlowFloatingPoint { let value = mean(squeezingAxes: axes) let count = Raw.gather(params: shapeTensor, indices: axes).product() return (value, { [shape = shapeTensor] in - $0.broadcast(toShape: shape) / Tensor(count) + var res = $0 + for i in axes.array.scalars { res = res.expandingShape(at: Int(i)) } + return res.broadcast(toShape: shape) / Tensor(count) }) } } diff --git a/test/TensorFlowRuntime/tensor_autodiff_runtime.swift b/test/TensorFlowRuntime/tensor_autodiff_runtime.swift index a17dc3d3d034c..69cbe26b88284 100644 --- a/test/TensorFlowRuntime/tensor_autodiff_runtime.swift +++ b/test/TensorFlowRuntime/tensor_autodiff_runtime.swift @@ -98,38 +98,39 @@ TensorADTests.testAllBackends("Abs") { TensorADTests.testAllBackends("sum") { let input = Tensor(repeating: 42, shape: [2, 2]) let sumPullbackScalar = pullback(at: input) { (a: Tensor) in a.sum() } + let sumPullbackSqueezingAxes = pullback(at: input) { (a: Tensor) in a.sum(squeezingAxes: 0, 1) } let sumPullbackAlongAxes = pullback(at: input) { (a: Tensor) in a.sum(alongAxes: 0, 1) } let expected = Tensor(ones: [2, 2]) expectEqual(expected, sumPullbackScalar(Tensor(1))) - // expectEqual(expected, sumPullbackSqueezingAxes(Tensor(1))) + expectEqual(expected, sumPullbackSqueezingAxes(Tensor(1))) expectEqual(expected, sumPullbackAlongAxes(Tensor(1))) expectEqual(expected * 3, sumPullbackScalar(Tensor(3))) - // expectEqual(expected * 3, sumPullbackSqueezingAxes(Tensor(3))) + expectEqual(expected * 3, sumPullbackSqueezingAxes(Tensor(3))) expectEqual(expected * 3, sumPullbackAlongAxes(Tensor(3))) } TensorADTests.testAllBackends("mean") { let meanGradScalar = gradient { (a: Tensor) in a.mean() } - // let meanGradSqueezingAxes = gradient { (a: Tensor) in a.mean(squeezingAxes: 0, 1) } + let meanGradSqueezingAxes = gradient { (a: Tensor) in a.mean(squeezingAxes: 0, 1) } let meanGradAlongAxes = gradient { (a: Tensor) in a.mean(alongAxes: 0, 1) } let input = Tensor(ones: [2, 2]) let expected = Tensor(repeating: 0.25, shape: [2, 2]) expectEqual(expected, meanGradScalar(input)) - // expectEqual(expected, meanGradSqueezingAxes(input)) + expectEqual(expected, meanGradSqueezingAxes(input)) expectEqual(expected, meanGradAlongAxes(input)) } TensorADTests.testAllBackends("variance") { let varianceGradScalar = gradient { (a: Tensor) in a.variance() } - // let varianceGradSqueezingAxes = gradient { (a: Tensor) in a.variance(squeezingAxes: 0, 1) } + let varianceGradSqueezingAxes = gradient { (a: Tensor) in a.variance(squeezingAxes: 0, 1) } let varianceGradAlongAxes = gradient { (a: Tensor) in a.variance(alongAxes: 0, 1) } let input: Tensor = [[1, 2], [3, 4]] let expected: Tensor = [[-0.75, -0.25], [0.25, 0.75]] expectEqual(expected, varianceGradScalar(input)) - // expectEqual(expected, varianceGradSqueezingAxes(input)) + expectEqual(expected, varianceGradSqueezingAxes(input)) expectEqual(expected, varianceGradAlongAxes(input)) }