Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 11 additions & 16 deletions stdlib/public/TensorFlow/Gradients.swift
Original file line number Diff line number Diff line change
Expand Up @@ -553,12 +553,10 @@ extension Tensor where Scalar : TensorFlowFloatingPoint {
}

@inlinable
func _vjpExpandingShape(
at shapeIndex: Int
) -> (Tensor, (Tensor) -> Tensor) {
let value = expandingShape(at: shapeIndex)
func _vjpExpandingShape(at axes: [Int]) -> (Tensor, (Tensor) -> Tensor) {
let value = self.expandingShape(at: axes)
return (value, { v in
v.squeezingShape(at: shapeIndex)
v.squeezingShape(at: axes)
})
}
}
Expand All @@ -579,7 +577,11 @@ extension Tensor where Scalar : TensorFlowFloatingPoint {
squeezingAxes axes: Tensor<Int32>
) -> (Tensor, (Tensor) -> Tensor) {
let value = sum(squeezingAxes: axes)
return (value, { [shape = shapeTensor] in $0.broadcast(toShape: shape) })
return (value, { [shape = shapeTensor] in
var res = $0
for i in axes.array.scalars { res = res.expandingShape(at: Int(i)) }
return res.broadcast(toShape: shape)
})
}

@inlinable
Expand All @@ -591,23 +593,16 @@ extension Tensor where Scalar : TensorFlowFloatingPoint {
})
}

@inlinable
func _vjpMean(squeezingAxes axes: [Int]) -> (Tensor, (Tensor) -> Tensor) {
let value = mean(squeezingAxes: axes)
return (value, { [shape = shapeTensor,
count = axes.map { shape[$0] }.reduce(1, *)] in
$0.broadcast(toShape: shape) / Tensor(Scalar(count))
})
}

@inlinable
func _vjpMean(
squeezingAxes axes: Tensor<Int32>
) -> (Tensor, (Tensor) -> Tensor) {
let value = mean(squeezingAxes: axes)
let count = Raw.gather(params: shapeTensor, indices: axes).product()
return (value, { [shape = shapeTensor] in
$0.broadcast(toShape: shape) / Tensor(count)
var res = $0
for i in axes.array.scalars { res = res.expandingShape(at: Int(i)) }
return res.broadcast(toShape: shape) / Tensor(count)
})
}
}
Expand Down
16 changes: 13 additions & 3 deletions stdlib/public/TensorFlow/Tensor.swift
Original file line number Diff line number Diff line change
Expand Up @@ -706,14 +706,24 @@ public extension Tensor {
}

/// Returns a shape-expanded `Tensor`, with a dimension of 1 inserted at the
/// specified shape index.
/// specified shape indices.
@inlinable @inline(__always)
@differentiable(wrt: self where Scalar : TensorFlowFloatingPoint)
func expandingShape(at axes: Int...) -> Tensor {
return expandingShape(at: axes)
}

/// Returns a shape-expanded `Tensor`, with a dimension of 1 inserted at the
/// specified shape indices.
@inlinable @inline(__always)
@differentiable(
wrt: self, vjp: _vjpExpandingShape(at:)
where Scalar : TensorFlowFloatingPoint
)
func expandingShape(at shapeIndex: Int) -> Tensor {
return Raw.expandDims(self, dim: Tensor<Int32>(Int32(shapeIndex)))
func expandingShape(at axes: [Int]) -> Tensor {
var res = self
for i in axes { res = Raw.expandDims(res, dim: Tensor<Int32>(Int32(i))) }
return res
}

/// Remove the specified dimensions of size 1 from the shape of a tensor. If
Expand Down
14 changes: 14 additions & 0 deletions test/TensorFlowRuntime/tensor.swift
Original file line number Diff line number Diff line change
Expand Up @@ -674,6 +674,20 @@ TensorTests.testAllBackends("MLPClassifierStruct") {
expectPointwiseNearlyEqual([0.816997], prediction.scalars)
}

TensorTests.testAllBackends("ExpandingShape") {
// 2 x 3 -> 1 x 2 x 1 x 3 x 1
let matrix = Tensor<Int32>([[0, 1, 2], [3, 4, 5]])
let reshaped = matrix.expandingShape(at: 0,2,4)

expectEqual([1, 2, 1, 3, 1], reshaped.shape)
expectEqual(Array(0..<6), reshaped.scalars)

// 1 x 2 x 1 x 3 x 1 -> 2 x 3
let rereshaped = reshaped.squeezingShape(at: 0,2,4)
expectEqual([2, 3], rereshaped.shape)
expectEqual(Array(0..<6), rereshaped.scalars)
}

TensorTests.testAllBackends("Reshape") {
// 2 x 3 -> 1 x 3 x 1 x 2 x 1
let matrix = Tensor<Int32>([[0, 1, 2], [3, 4, 5]])
Expand Down