diff --git a/Sources/TensorFlow/Operators/Basic.swift b/Sources/TensorFlow/Operators/Basic.swift
index 0288c6294..3d151fede 100644
--- a/Sources/TensorFlow/Operators/Basic.swift
+++ b/Sources/TensorFlow/Operators/Basic.swift
@@ -769,14 +769,38 @@ extension Tensor where Scalar: TensorFlowFloatingPoint {
 //===------------------------------------------------------------------------------------------===//
 
 public extension Tensor where Scalar: Numeric {
-    /// Returns a padded tensor according to the specified padding sizes.
+    /// A mode that dictates how a tensor is padded.
+    enum PaddingMode {
+        /// Pads with a constant value.
+        case constant(Scalar)
+        /// Mirrors values along padding dimensions, excluding the edge value.
+        case reflect
+        /// Mirrors values along padding dimensions, including the edge value.
+        case symmetric
+    }
+
+    /// Returns a tensor padded with a constant value according to the specified padding sizes.
     @inlinable
-    @differentiable(wrt: self, vjp: _vjpPadded(forSizes:with:) where Scalar: TensorFlowFloatingPoint)
+    @differentiable(wrt: self where Scalar: TensorFlowFloatingPoint)
     func padded(forSizes sizes: [(before: Int, after: Int)], with value: Scalar = 0) -> Tensor {
+        padded(forSizes: sizes, mode: .constant(value))
+    }
+
+    /// Returns a padded tensor according to the specified padding sizes and mode.
+    @inlinable
+    @differentiable(wrt: self, vjp: _vjpPadded(forSizes:mode:) where Scalar: TensorFlowFloatingPoint)
+    func padded(forSizes sizes: [(before: Int, after: Int)], mode: PaddingMode) -> Tensor {
         let paddings = Tensor<Int32>(
             shape: [sizes.count, 2],
             scalars: sizes.flatMap { [Int32($0.before), Int32($0.after)] })
-        return Raw.padV2(self, paddings: paddings, constantValues: Tensor(value))
+        switch mode {
+        case .constant(let constantValue):
+            return Raw.padV2(self, paddings: paddings, constantValues: Tensor(constantValue))
+        case .reflect:
+            return Raw.mirrorPad(self, paddings: paddings, mode: .reflect)
+        case .symmetric:
+            return Raw.mirrorPad(self, paddings: paddings, mode: .symmetric)
+        }
     }
 }
 
@@ -784,18 +808,26 @@ internal extension Tensor where Scalar: TensorFlowFloatingPoint {
     @inlinable
     func _vjpPadded(
         forSizes sizes: [(before: Int, after: Int)],
-        with value: Scalar
+        mode: PaddingMode
     ) -> (Tensor, (Tensor) -> Tensor) {
-        let result = padded(forSizes: sizes, with: value)
+        let result = padded(forSizes: sizes, mode: mode)
         return (result, { [rank = rankTensor, shape = shapeTensor] v in
             let paddings = Tensor<Int32>(
                 shape: [sizes.count, 2],
                 scalars: sizes.flatMap { [Int32($0.before), Int32($0.after)] })
-            let padBefore = Raw.slice(paddings,
-                begin: Tensor<Int32>([0, 0]),
-                size: Tensor<Int32>(stacking: [rank, Tensor<Int32>(1)]))
-            let begin = padBefore.reshaped(to: [-1])
-            return Raw.slice(v, begin: begin, size: shape)
+            switch mode {
+            case .constant:
+                let padBefore = Raw.slice(
+                    paddings,
+                    begin: Tensor<Int32>([0, 0]),
+                    size: Tensor<Int32>(stacking: [rank, Tensor<Int32>(1)]))
+                let begin = padBefore.reshaped(to: [-1])
+                return v.slice(lowerBounds: begin, sizes: shape)
+            case .reflect:
+                return Raw.mirrorPadGrad(v, paddings: paddings, mode: .reflect)
+            case .symmetric:
+                return Raw.mirrorPadGrad(v, paddings: paddings, mode: .symmetric)
+            }
         })
     }
 }
diff --git a/Tests/TensorFlowTests/OperatorTests/BasicTests.swift b/Tests/TensorFlowTests/OperatorTests/BasicTests.swift
index bb63a2f27..e66fcc387 100644
--- a/Tests/TensorFlowTests/OperatorTests/BasicTests.swift
+++ b/Tests/TensorFlowTests/OperatorTests/BasicTests.swift
@@ -40,6 +40,39 @@ final class BasicOperatorTests: XCTestCase {
         XCTAssertEqual(paddedTensor, target)
     }
 
+    func testPaddedConstant() {
+        let x = Tensor<Float>(ones: [2, 2])
+        let target = Tensor<Float>([[3, 3, 3], [1, 1, 3], [1, 1, 3]])
+        let paddedTensor = x.padded(forSizes: [(1, 0), (0, 1)], mode: .constant(3.0))
+        XCTAssertEqual(paddedTensor, target)
+    }
+
+    func testPaddedReflect() {
+        let x = Tensor<Float>([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
+        let target = Tensor<Float>([
+            [7, 8, 9, 8, 7],
+            [4, 5, 6, 5, 4],
+            [1, 2, 3, 2, 1],
+            [4, 5, 6, 5, 4],
+            [7, 8, 9, 8, 7]
+        ])
+        let paddedTensor = x.padded(forSizes: [(2, 0), (0, 2)], mode: .reflect)
+        XCTAssertEqual(paddedTensor, target)
+    }
+
+    func testPaddedSymmetric() {
+        let x = Tensor<Float>([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
+        let target = Tensor<Float>([
+            [4, 5, 6, 6, 5],
+            [1, 2, 3, 3, 2],
+            [1, 2, 3, 3, 2],
+            [4, 5, 6, 6, 5],
+            [7, 8, 9, 9, 8]
+        ])
+        let paddedTensor = x.padded(forSizes: [(2, 0), (0, 2)], mode: .symmetric)
+        XCTAssertEqual(paddedTensor, target)
+    }
+
     func testVJPPadded() {
         let x = Tensor<Float>(ones: [3, 2])
         let target = Tensor<Float>([[2, 2], [2, 2], [2, 2]])
@@ -50,6 +83,36 @@ final class BasicOperatorTests: XCTestCase {
         XCTAssertEqual(grads, target)
     }
 
+    func testVJPPaddedConstant() {
+        let x = Tensor<Float>(ones: [3, 2])
+        let target = Tensor<Float>([[2, 2], [2, 2], [2, 2]])
+        let grads = x.gradient { a -> Tensor<Float> in
+            let paddedTensor = a.padded(forSizes: [(1, 0), (0, 1)], mode: .constant(3.0))
+            return (paddedTensor * paddedTensor).sum()
+        }
+        XCTAssertEqual(grads, target)
+    }
+
+    func testVJPPaddedReflect() {
+        let x = Tensor<Float>([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
+        let target = Tensor<Float>([[4, 8, 6], [32, 40, 24], [56, 64, 36]])
+        let grads = x.gradient { a -> Tensor<Float> in
+            let paddedTensor = a.padded(forSizes: [(2, 0), (0, 2)], mode: .reflect)
+            return (paddedTensor * paddedTensor).sum()
+        }
+        XCTAssertEqual(grads, target)
+    }
+
+    func testVJPPaddedSymmetric() {
+        let x = Tensor<Float>([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
+        let target = Tensor<Float>([[4, 16, 24], [16, 40, 48], [14, 32, 36]])
+        let grads = x.gradient { a -> Tensor<Float> in
+            let paddedTensor = a.padded(forSizes: [(2, 0), (0, 2)], mode: .symmetric)
+            return (paddedTensor * paddedTensor).sum()
+        }
+        XCTAssertEqual(grads, target)
+    }
+
     func testElementIndexing() {
         // NOTE: cannot test multiple `Tensor.shape` or `Tensor.scalars` directly
         // until send and receive are implemented (without writing a bunch of mini
@@ -599,7 +662,13 @@ final class BasicOperatorTests: XCTestCase {
         ("testGathering", testGathering),
         ("testBatchGathering", testBatchGathering),
         ("testPadded", testPadded),
+        ("testPaddedConstant", testPaddedConstant),
+        ("testPaddedReflect", testPaddedReflect),
+        ("testPaddedSymmetric", testPaddedSymmetric),
         ("testVJPPadded", testVJPPadded),
+        ("testVJPPaddedConstant", testVJPPaddedConstant),
+        ("testVJPPaddedReflect", testVJPPaddedReflect),
+        ("testVJPPaddedSymmetric", testVJPPaddedSymmetric),
         ("testElementIndexing", testElementIndexing),
         ("testElementIndexingAssignment", testElementIndexingAssignment),
         ("testNestedElementIndexing", testNestedElementIndexing),