diff --git a/stdlib/public/TensorFlow/Gradients.swift b/stdlib/public/TensorFlow/Gradients.swift index dc4f2cfe942b5..fee63b347af06 100644 --- a/stdlib/public/TensorFlow/Gradients.swift +++ b/stdlib/public/TensorFlow/Gradients.swift @@ -577,10 +577,9 @@ extension Tensor where Scalar : TensorFlowFloatingPoint { squeezingAxes axes: Tensor ) -> (Tensor, (Tensor) -> Tensor) { let value = sum(squeezingAxes: axes) - return (value, { [shape = shapeTensor] in - var res = $0 - for i in axes.array.scalars { res = res.expandingShape(at: Int(i)) } - return res.broadcast(toShape: shape) + return (value, { [shape = shapeTensor] v in + let unsqueezed = v.expandingShape(at: axes.scalars.map { Int($0) }) + return unsqueezed.broadcast(toShape: shape) }) } @@ -599,10 +598,9 @@ extension Tensor where Scalar : TensorFlowFloatingPoint { ) -> (Tensor, (Tensor) -> Tensor) { let value = mean(squeezingAxes: axes) let count = Raw.gather(params: shapeTensor, indices: axes).product() - return (value, { [shape = shapeTensor] in - var res = $0 - for i in axes.array.scalars { res = res.expandingShape(at: Int(i)) } - return res.broadcast(toShape: shape) / Tensor(count) + return (value, { [shape = shapeTensor] v in + let unsqueezed = v.expandingShape(at: axes.scalars.map { Int($0) }) + return unsqueezed.broadcast(toShape: shape) / Tensor(count) }) } }