diff --git a/Sources/DeepLearning/Layer.swift b/Sources/DeepLearning/Layer.swift index a82ec6bf6..a71189183 100644 --- a/Sources/DeepLearning/Layer.swift +++ b/Sources/DeepLearning/Layer.swift @@ -332,6 +332,127 @@ public extension Dense { } } +/// A 1-D convolution layer (e.g. temporal convolution over a time-series). +/// +/// This layer creates a convolution filter that is convolved with the layer input to produce a +/// tensor of outputs. +@_fixed_layout +public struct Conv1D: Layer { + /// The 3-D convolution kernel `[width, inputChannels, outputChannels]`. + public var filter: Tensor + /// The bias vector `[outputChannels]`. + public var bias: Tensor + /// An activation function. + public typealias Activation = @differentiable (Tensor) -> Tensor + /// The element-wise activation function. + @noDerivative public let activation: Activation + /// The stride of the sliding window for temporal dimension. + @noDerivative public let stride: Int32 + /// The padding algorithm for convolution. + @noDerivative public let padding: Padding + + /// Creates a `Conv1D` layer with the specified filter, bias, activation function, stride, and + /// padding. + /// + /// - Parameters: + /// - filter: The 3-D convolution kernel `[width, inputChannels, outputChannels]`. + /// - bias: The bias vector `[outputChannels]`. + /// - activation: The element-wise activation function. + /// - stride: The stride of the sliding window for temporal dimension. + /// - padding: The padding algorithm for convolution. + public init( + filter: Tensor, + bias: Tensor, + activation: @escaping Activation, + stride: Int, + padding: Padding + ) { + self.filter = filter + self.bias = bias + self.activation = activation + self.stride = Int32(stride) + self.padding = padding + } + + /// Returns the output obtained from applying the layer to the given input. + /// + /// - Parameters: + /// - input: The input to the layer `[batchCount, width, inputChannels]`. + /// - context: The contextual information for the layer application, e.g. the current learning + /// phase. + /// - Returns: The output `[batchCount, newWidth, outputChannels]`. + @differentiable + public func applied(to input: Tensor, in _: Context) -> Tensor { + let conv2D = input.expandingShape(at: 1).convolved2D( + withFilter: filter.expandingShape(at: 0), strides: (1, 1, stride, 1), padding: padding) + return activation(conv2D.squeezingShape(at: 1) + bias) + } +} + +public extension Conv1D where Scalar.RawSignificand: FixedWidthInteger { + /// Creates a `Conv1D` layer with the specified filter shape, stride, padding, and + /// element-wise activation function. The filter tensor is initialized using Glorot uniform + /// initialization with the specified generator. The bias vector is initialized with zeros. + /// + /// - Parameters: + /// - filterShape: The 3-D shape of the filter, representing + /// `[width, inputChannels, outputChannels]`. + /// - stride: The stride of the sliding window for temporal dimension. + /// - padding: The padding algorithm for convolution. + /// - activation: The element-wise activation function. + /// - generator: The random number generator for initialization. + /// + /// - Note: Use `init(filterShape:stride:padding:activation:seed:)` for faster random + /// initialization. + init( + filterShape: (Int, Int, Int), + stride: Int = 1, + padding: Padding = .valid, + activation: @escaping Activation = identity, + generator: inout G + ) { + let filterTensorShape = TensorShape([ + Int32(filterShape.0), Int32(filterShape.1), Int32(filterShape.2)]) + self.init( + filter: Tensor(glorotUniform: filterTensorShape), + bias: Tensor(zeros: TensorShape([Int32(filterShape.2)])), + activation: activation, + stride: stride, + padding: padding) + } +} + +public extension Conv1D { + /// Creates a `Conv1D` layer with the specified filter shape, strides, padding, and + /// element-wise activation function. The filter tensor is initialized using Glorot uniform + /// initialization with the specified seed. The bias vector is initialized with zeros. + /// + /// - Parameters: + /// - filterShape: The 3-D shape of the filter, representing + /// `[width, inputChannels, outputChannels]`. + /// - stride: The stride of the sliding window for temporal dimension. + /// - padding: The padding algorithm for convolution. + /// - activation: The element-wise activation function. + /// - seed: The random seed for initialization. The default value is random. + init( + filterShape: (Int, Int, Int), + stride: Int = 1, + padding: Padding = .valid, + activation: @escaping Activation = identity, + seed: (Int64, Int64) = (Int64.random(in: Int64.min..: Layer { public typealias Activation = @differentiable (Tensor) -> Tensor /// The element-wise activation function. @noDerivative public let activation: Activation - /// The strides of the sliding window for each dimension of a 4-D input. - /// Strides in non-spatial dimensions must be `1`. + /// The strides of the sliding window for spatial dimensions. @noDerivative public let strides: (Int32, Int32) /// The padding algorithm for convolution. @noDerivative public let padding: Padding @@ -356,11 +476,11 @@ public struct Conv2D: Layer { /// padding. /// /// - Parameters: - /// - filter: The filter. - /// - bias: The bias. - /// - activation: The activation activation. - /// - strides: The strides. - /// - padding: The padding. + /// - filter: The 4-D convolution kernel. + /// - bias: The bias vector. + /// - activation: The element-wise activation function. + /// - strides: The strides of the sliding window for spatial dimensions. + /// - padding: The padding algorithm for convolution. public init( filter: Tensor, bias: Tensor, @@ -396,10 +516,10 @@ public extension Conv2D { /// initialization with the specified generator. The bias vector is initialized with zeros. /// /// - Parameters: - /// - filterShape: The shape of the filter, represented by a tuple of `4` integers. - /// - strides: The strides. - /// - padding: The padding. - /// - activation: The activation function. + /// - filterShape: The shape of the 4-D convolution kernel. + /// - strides: The strides of the sliding window for spatial dimensions. + /// - padding: The padding algorithm for convolution. + /// - activation: The element-wise activation function. /// - generator: The random number generator for initialization. /// /// - Note: Use `init(filterShape:strides:padding:activation:seed:)` for faster random @@ -429,14 +549,11 @@ public extension Conv2D { /// initialization with the specified seed. The bias vector is initialized with zeros. /// /// - Parameters: - /// - filterShape: The shape of the filter, represented by a tuple of `4` integers. - /// - strides: The strides. - /// - padding: The padding. - /// - activation: The activation function. + /// - filterShape: The shape of the 4-D convolution kernel. + /// - strides: The strides of the sliding window for spatial dimensions. + /// - padding: The padding algorithm for convolution. + /// - activation: The element-wise activation function. /// - seed: The random seed for initialization. The default value is random. - /// - /// - Note: Use `init(filterShape:strides:padding:activation:seed:)` for faster random - /// initialization. init( filterShape: (Int, Int, Int, Int), strides: (Int, Int) = (1, 1), @@ -449,11 +566,11 @@ public extension Conv2D { Int32(filterShape.0), Int32(filterShape.1), Int32(filterShape.2), Int32(filterShape.3)]) self.init( - filter: Tensor(glorotUniform: filterTensorShape, seed: seed), - bias: Tensor(zeros: TensorShape([Int32(filterShape.3)])), - activation: activation, - strides: (Int32(strides.0), Int32(strides.1)), - padding: padding) + filter: Tensor(glorotUniform: filterTensorShape, seed: seed), + bias: Tensor(zeros: TensorShape([Int32(filterShape.3)])), + activation: activation, + strides: (Int32(strides.0), Int32(strides.1)), + padding: padding) } } @@ -582,6 +699,47 @@ public struct BatchNorm: Layer { } } +/// A max pooling layer for temporal data. +@_fixed_layout +public struct MaxPool1D: Layer { + /// The size of the sliding reduction window for pooling. + @noDerivative let poolSize: Int32 + /// The stride of the sliding window for temporal dimension. + @noDerivative let stride: Int32 + /// The padding algorithm for pooling. + @noDerivative let padding: Padding + + /// Creates a max pooling layer. + /// + /// - Parameters: + /// - poolSize: The size of the sliding reduction window for pooling. + /// - stride: The stride of the sliding window for temporal dimension. + /// - padding: The padding algorithm for pooling. + public init( + poolSize: Int, + stride: Int, + padding: Padding + ) { + self.poolSize = Int32(poolSize) + self.stride = Int32(stride) + self.padding = padding + } + + /// Returns the output obtained from applying the layer to the given input. + /// + /// - Parameters: + /// - input: The input to the layer. + /// - context: The contextual information for the layer application, e.g. the current learning + /// phase. + /// - Returns: The output. + @differentiable + public func applied(to input: Tensor, in _: Context) -> Tensor { + return input.expandingShape(at: 1).maxPooled( + kernelSize: (1, 1, poolSize, 1), strides: (1, 1, stride, 1), padding: padding + ).squeezingShape(at: 1) + } +} + /// A max pooling layer for spatial data. @_fixed_layout public struct MaxPool2D: Layer { @@ -628,7 +786,48 @@ public struct MaxPool2D: Layer { @differentiable public func applied(to input: Tensor, in _: Context) -> Tensor { return input.maxPooled( - kernelSize: poolSize, strides: strides, padding: padding) + kernelSize: poolSize, strides: strides, padding: padding) + } +} + +/// An average pooling layer for temporal data. +@_fixed_layout +public struct AvgPool1D: Layer { + /// The size of the sliding reduction window for pooling. + @noDerivative let poolSize: Int32 + /// The stride of the sliding window for temporal dimension. + @noDerivative let stride: Int32 + /// The padding algorithm for pooling. + @noDerivative let padding: Padding + + /// Creates an average pooling layer. + /// + /// - Parameters: + /// - poolSize: The size of the sliding reduction window for pooling. + /// - stride: The stride of the sliding window for temporal dimension. + /// - padding: The padding algorithm for pooling. + public init( + poolSize: Int, + stride: Int, + padding: Padding + ) { + self.poolSize = Int32(poolSize) + self.stride = Int32(stride) + self.padding = padding + } + + /// Returns the output obtained from applying the layer to the given input. + /// + /// - Parameters: + /// - input: The input to the layer. + /// - context: The contextual information for the layer application, e.g. the current learning + /// phase. + /// - Returns: The output. + @differentiable + public func applied(to input: Tensor, in _: Context) -> Tensor { + return input.expandingShape(at: 1).averagePooled( + kernelSize: (1, 1, poolSize, 1), strides: (1, 1, stride, 1), padding: padding + ).squeezingShape(at: 1) } } diff --git a/Tests/DeepLearningTests/LayerTests.swift b/Tests/DeepLearningTests/LayerTests.swift new file mode 100644 index 000000000..a0f75703b --- /dev/null +++ b/Tests/DeepLearningTests/LayerTests.swift @@ -0,0 +1,48 @@ +// Copyright 2019 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import XCTest +@testable import DeepLearning + +final class LayerTests: XCTestCase { + func testConv1D() { + let filter = Tensor(ones: [3, 1, 2]) * Tensor([[[0.33333333, 1]]]) + let bias = Tensor([0, 1]) + let layer = Conv1D(filter: filter, bias: bias, activation: identity, stride: 1, padding: .valid) + let input = Tensor([[0, 1, 2, 3, 4], [10, 11, 12, 13, 14]]).expandingShape(at: 2) + let output = layer.inferring(from: input) + let expected = Tensor([[[1, 4], [2, 7], [3, 10]], [[11, 34], [12, 37], [13, 40]]]) + XCTAssertEqual(round(output), expected) + } + + func testMaxPool1D() { + let layer = MaxPool1D(poolSize: 3, stride: 1, padding: .valid) + let input = Tensor([[0, 1, 2, 3, 4], [10, 11, 12, 13, 14]]).expandingShape(at: 2) + let output = layer.inferring(from: input) + let expected = Tensor([[[2], [3], [4]], [[12], [13], [14]]]) + XCTAssertEqual(round(output), expected) + } + + func testAvgPool1D() { + let layer = AvgPool1D(poolSize: 3, stride: 1, padding: .valid) + let input = Tensor([[0, 1, 2, 3, 4], [10, 11, 12, 13, 14]]).expandingShape(at: 2) + let output = layer.inferring(from: input) + let expected = Tensor([[[1], [2], [3]], [[11], [12], [13]]]) + XCTAssertEqual(round(output), expected) + } + + static var allTests = [ + ("testConv1D", testConv1D), ("testMaxPool1D", testMaxPool1D), ("testAvgPool1D", testAvgPool1D) + ] +} diff --git a/Tests/DeepLearningTests/XCTestManifests.swift b/Tests/DeepLearningTests/XCTestManifests.swift index a731b951c..96a9048a5 100644 --- a/Tests/DeepLearningTests/XCTestManifests.swift +++ b/Tests/DeepLearningTests/XCTestManifests.swift @@ -21,6 +21,7 @@ public func allTests() -> [XCTestCaseEntry] { testCase(PRNGTests.allTests), testCase(TrivialModelTests.allTests), testCase(SequentialTests.allTests), + testCase(LayerTests.allTests), ] } #endif