From 82797fdb1e4dd86e04d2ab7ffda29ff9c0613ea0 Mon Sep 17 00:00:00 2001 From: Dave Fernandes Date: Sat, 16 Mar 2019 18:03:42 -0400 Subject: [PATCH 1/7] Added Conv1d, MaxPool1D and AvgPool1D layers --- Sources/DeepLearning/Layer.swift | 207 ++++++++++++++++++++++++++++++- 1 file changed, 202 insertions(+), 5 deletions(-) diff --git a/Sources/DeepLearning/Layer.swift b/Sources/DeepLearning/Layer.swift index a82ec6bf6..11dc7600b 100644 --- a/Sources/DeepLearning/Layer.swift +++ b/Sources/DeepLearning/Layer.swift @@ -332,6 +332,125 @@ public extension Dense { } } +/// A 1-D convolution layer (e.g. temporal convolution over a time-series). +/// +/// This layer creates a convolution filter that is convolved with the layer input to produce a +/// tensor of outputs. +@_fixed_layout +public struct Conv1D: Layer { + /// The 3-D convolution kernel. + public var filter: Tensor + /// The bias vector. + public var bias: Tensor + /// An activation function. + public typealias Activation = @differentiable (Tensor) -> Tensor + /// The element-wise activation function. + @noDerivative public let activation: Activation + /// The stride of the sliding window for temporal dimension. + @noDerivative public let stride: Int32 + /// The padding algorithm for convolution. + @noDerivative public let padding: Padding + + /// Creates a `Conv1D` layer with the specified filter, bias, activation function, stride, and + /// padding. + /// + /// - Parameters: + /// - filter: The filter. + /// - bias: The bias. + /// - activation: The activation activation. + /// - stride: The stride. + /// - padding: The padding. + public init( + filter: Tensor, + bias: Tensor, + activation: @escaping Activation, + stride: Int, + padding: Padding + ) { + self.filter = filter + self.bias = bias + self.activation = activation + self.stride = Int32(stride) + self.padding = padding + } + + /// Returns the output obtained from applying the layer to the given input. + /// + /// - Parameters: + /// - input: The input to the layer. + /// - context: The contextual information for the layer application, e.g. the current learning + /// phase. + /// - Returns: The output. + @differentiable + public func applied(to input: Tensor, in _: Context) -> Tensor { + let conv2D = input.convolved2D(withFilter: filter.expandingShape(at: 1), + strides: (1, 1, stride, 1), padding: padding) + return activation(conv2D.squeezingShape(at: 1) + bias) + } +} + +public extension Conv1D where Scalar.RawSignificand: FixedWidthInteger { + /// Creates a `Conv1D` layer with the specified filter shape, stride, padding, and + /// element-wise activation function. The filter tensor is initialized using Glorot uniform + /// initialization with the specified generator. The bias vector is initialized with zeros. + /// + /// - Parameters: + /// - filterShape: The shape of the filter, represented by a tuple of `3` integers. + /// - stride: The stride. + /// - padding: The padding. + /// - activation: The activation function. + /// - generator: The random number generator for initialization. + /// + /// - Note: Use `init(filterShape:stride:padding:activation:seed:)` for faster random + /// initialization. + init( + filterShape: (Int, Int, Int), + stride: Int = 1, + padding: Padding = .valid, + activation: @escaping Activation = identity, + generator: inout G + ) { + let filterTensorShape = TensorShape([ + Int32(filterShape.0), Int32(filterShape.1), Int32(filterShape.2)]) + self.init( + filter: Tensor(glorotUniform: filterTensorShape), + bias: Tensor(zeros: TensorShape([Int32(filterShape.2)])), + activation: activation, + stride: stride, + padding: padding) + } +} + +public extension Conv1D { + /// Creates a `Conv2D` layer with the specified filter shape, strides, padding, and + /// element-wise activation function. The filter tensor is initialized using Glorot uniform + /// initialization with the specified seed. The bias vector is initialized with zeros. + /// + /// - Parameters: + /// - filterShape: The shape of the filter, represented by a tuple of `4` integers. + /// - strides: The strides. + /// - padding: The padding. + /// - activation: The activation function. + /// - seed: The random seed for initialization. The default value is random. + init( + filterShape: (Int, Int, Int), + stride: Int = 1, + padding: Padding = .valid, + activation: @escaping Activation = identity, + seed: (Int64, Int64) = (Int64.random(in: Int64.min..: Layer { public typealias Activation = @differentiable (Tensor) -> Tensor /// The element-wise activation function. @noDerivative public let activation: Activation - /// The strides of the sliding window for each dimension of a 4-D input. - /// Strides in non-spatial dimensions must be `1`. + /// The strides of the sliding window for spatial dimensions. @noDerivative public let strides: (Int32, Int32) /// The padding algorithm for convolution. @noDerivative public let padding: Padding @@ -434,9 +552,6 @@ public extension Conv2D { /// - padding: The padding. /// - activation: The activation function. /// - seed: The random seed for initialization. The default value is random. - /// - /// - Note: Use `init(filterShape:strides:padding:activation:seed:)` for faster random - /// initialization. init( filterShape: (Int, Int, Int, Int), strides: (Int, Int) = (1, 1), @@ -582,6 +697,47 @@ public struct BatchNorm: Layer { } } +/// A max pooling layer for temporal data. +@_fixed_layout +public struct MaxPool1D: Layer { + /// The size of the sliding reduction window for pooling. + @noDerivative let poolSize: Int32 + /// The stride of the sliding window for temporal dimension. + @noDerivative let stride: Int32 + /// The padding algorithm for pooling. + @noDerivative let padding: Padding + + /// Creates a max pooling layer. + /// + /// - Parameters: + /// - poolSize: Factor by which to downscale. + /// - stride: The stride. + /// - padding: The padding. + public init( + poolSize: Int, + stride: Int, + padding: Padding + ) { + self.poolSize = Int32(poolSize) + self.stride = Int32(stride) + self.padding = padding + } + + /// Returns the output obtained from applying the layer to the given input. + /// + /// - Parameters: + /// - input: The input to the layer. + /// - context: The contextual information for the layer application, e.g. the current learning + /// phase. + /// - Returns: The output. + @differentiable + public func applied(to input: Tensor, in _: Context) -> Tensor { + return input.expandingShape(at: 1).maxPooled( + kernelSize: (1, 1, poolSize, 1), strides: (1, 1, stride, 1), padding: padding + ).squeezingShape(at: 1) + } +} + /// A max pooling layer for spatial data. @_fixed_layout public struct MaxPool2D: Layer { @@ -632,6 +788,47 @@ public struct MaxPool2D: Layer { } } +/// An average pooling layer for temporal data. +@_fixed_layout +public struct AvgPool1D: Layer { + /// The size of the sliding reduction window for pooling. + @noDerivative let poolSize: Int32 + /// The stride of the sliding window for temporal dimension. + @noDerivative let stride: Int32 + /// The padding algorithm for pooling. + @noDerivative let padding: Padding + + /// Creates an average pooling layer. + /// + /// - Parameters: + /// - poolSize: Factor by which to downscale. + /// - stride: The stride. + /// - padding: The padding. + public init( + poolSize: Int, + stride: Int, + padding: Padding + ) { + self.poolSize = Int32(poolSize) + self.stride = Int32(stride) + self.padding = padding + } + + /// Returns the output obtained from applying the layer to the given input. + /// + /// - Parameters: + /// - input: The input to the layer. + /// - context: The contextual information for the layer application, e.g. the current learning + /// phase. + /// - Returns: The output. + @differentiable + public func applied(to input: Tensor, in _: Context) -> Tensor { + return input.expandingShape(at: 1).averagePooled( + kernelSize: (1, 1, poolSize, 1), strides: (1, 1, stride, 1), padding: padding + ).squeezingShape(at: 1) + } +} + /// An average pooling layer for spatial data. @_fixed_layout public struct AvgPool2D: Layer { From 74fa216a1ebf887c6952da99c84cdaa73c2c7ccf Mon Sep 17 00:00:00 2001 From: Dave Fernandes Date: Sun, 17 Mar 2019 18:21:44 -0400 Subject: [PATCH 2/7] Fixed input dimensions --- Sources/DeepLearning/Layer.swift | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Sources/DeepLearning/Layer.swift b/Sources/DeepLearning/Layer.swift index 11dc7600b..bbc351f9a 100644 --- a/Sources/DeepLearning/Layer.swift +++ b/Sources/DeepLearning/Layer.swift @@ -383,7 +383,7 @@ public struct Conv1D: Layer { /// - Returns: The output. @differentiable public func applied(to input: Tensor, in _: Context) -> Tensor { - let conv2D = input.convolved2D(withFilter: filter.expandingShape(at: 1), + let conv2D = input.expandingShape(at: 1).convolved2D(withFilter: filter.expandingShape(at: 1), strides: (1, 1, stride, 1), padding: padding) return activation(conv2D.squeezingShape(at: 1) + bias) } @@ -428,7 +428,7 @@ public extension Conv1D { /// /// - Parameters: /// - filterShape: The shape of the filter, represented by a tuple of `4` integers. - /// - strides: The strides. + /// - stride: The stride. /// - padding: The padding. /// - activation: The activation function. /// - seed: The random seed for initialization. The default value is random. From 1df460a5d66acc49df7eb826c546c35b4e4a29ce Mon Sep 17 00:00:00 2001 From: Dave Fernandes Date: Tue, 19 Mar 2019 12:24:40 -0400 Subject: [PATCH 3/7] Fixed filter dimension; added tests --- Sources/DeepLearning/Layer.swift | 16 +++---- Tests/DeepLearningTests/LayerTests.swift | 48 +++++++++++++++++++ Tests/DeepLearningTests/XCTestManifests.swift | 1 + 3 files changed, 57 insertions(+), 8 deletions(-) create mode 100644 Tests/DeepLearningTests/LayerTests.swift diff --git a/Sources/DeepLearning/Layer.swift b/Sources/DeepLearning/Layer.swift index bbc351f9a..96934920e 100644 --- a/Sources/DeepLearning/Layer.swift +++ b/Sources/DeepLearning/Layer.swift @@ -355,8 +355,8 @@ public struct Conv1D: Layer { /// padding. /// /// - Parameters: - /// - filter: The filter. - /// - bias: The bias. + /// - filter: The filter (width, inputChannels, outputChannels). + /// - bias: The bias (dimensions: output channels). /// - activation: The activation activation. /// - stride: The stride. /// - padding: The padding. @@ -377,13 +377,13 @@ public struct Conv1D: Layer { /// Returns the output obtained from applying the layer to the given input. /// /// - Parameters: - /// - input: The input to the layer. + /// - input: The input to the layer (batchCount, width, inputChannels). /// - context: The contextual information for the layer application, e.g. the current learning /// phase. - /// - Returns: The output. + /// - Returns: The output (batchCount, newWidth, outputChannels). @differentiable public func applied(to input: Tensor, in _: Context) -> Tensor { - let conv2D = input.expandingShape(at: 1).convolved2D(withFilter: filter.expandingShape(at: 1), + let conv2D = input.expandingShape(at: 1).convolved2D(withFilter: filter.expandingShape(at: 0), strides: (1, 1, stride, 1), padding: padding) return activation(conv2D.squeezingShape(at: 1) + bias) } @@ -395,7 +395,7 @@ public extension Conv1D where Scalar.RawSignificand: FixedWidthInteger { /// initialization with the specified generator. The bias vector is initialized with zeros. /// /// - Parameters: - /// - filterShape: The shape of the filter, represented by a tuple of `3` integers. + /// - filterShape: The shape of the filter (width, inputChannels, outputChannels). /// - stride: The stride. /// - padding: The padding. /// - activation: The activation function. @@ -422,12 +422,12 @@ public extension Conv1D where Scalar.RawSignificand: FixedWidthInteger { } public extension Conv1D { - /// Creates a `Conv2D` layer with the specified filter shape, strides, padding, and + /// Creates a `Conv1D` layer with the specified filter shape, strides, padding, and /// element-wise activation function. The filter tensor is initialized using Glorot uniform /// initialization with the specified seed. The bias vector is initialized with zeros. /// /// - Parameters: - /// - filterShape: The shape of the filter, represented by a tuple of `4` integers. + /// - filterShape: The shape of the filter (width, inputChannels, outputChannels). /// - stride: The stride. /// - padding: The padding. /// - activation: The activation function. diff --git a/Tests/DeepLearningTests/LayerTests.swift b/Tests/DeepLearningTests/LayerTests.swift new file mode 100644 index 000000000..a0f75703b --- /dev/null +++ b/Tests/DeepLearningTests/LayerTests.swift @@ -0,0 +1,48 @@ +// Copyright 2019 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import XCTest +@testable import DeepLearning + +final class LayerTests: XCTestCase { + func testConv1D() { + let filter = Tensor(ones: [3, 1, 2]) * Tensor([[[0.33333333, 1]]]) + let bias = Tensor([0, 1]) + let layer = Conv1D(filter: filter, bias: bias, activation: identity, stride: 1, padding: .valid) + let input = Tensor([[0, 1, 2, 3, 4], [10, 11, 12, 13, 14]]).expandingShape(at: 2) + let output = layer.inferring(from: input) + let expected = Tensor([[[1, 4], [2, 7], [3, 10]], [[11, 34], [12, 37], [13, 40]]]) + XCTAssertEqual(round(output), expected) + } + + func testMaxPool1D() { + let layer = MaxPool1D(poolSize: 3, stride: 1, padding: .valid) + let input = Tensor([[0, 1, 2, 3, 4], [10, 11, 12, 13, 14]]).expandingShape(at: 2) + let output = layer.inferring(from: input) + let expected = Tensor([[[2], [3], [4]], [[12], [13], [14]]]) + XCTAssertEqual(round(output), expected) + } + + func testAvgPool1D() { + let layer = AvgPool1D(poolSize: 3, stride: 1, padding: .valid) + let input = Tensor([[0, 1, 2, 3, 4], [10, 11, 12, 13, 14]]).expandingShape(at: 2) + let output = layer.inferring(from: input) + let expected = Tensor([[[1], [2], [3]], [[11], [12], [13]]]) + XCTAssertEqual(round(output), expected) + } + + static var allTests = [ + ("testConv1D", testConv1D), ("testMaxPool1D", testMaxPool1D), ("testAvgPool1D", testAvgPool1D) + ] +} diff --git a/Tests/DeepLearningTests/XCTestManifests.swift b/Tests/DeepLearningTests/XCTestManifests.swift index a731b951c..96a9048a5 100644 --- a/Tests/DeepLearningTests/XCTestManifests.swift +++ b/Tests/DeepLearningTests/XCTestManifests.swift @@ -21,6 +21,7 @@ public func allTests() -> [XCTestCaseEntry] { testCase(PRNGTests.allTests), testCase(TrivialModelTests.allTests), testCase(SequentialTests.allTests), + testCase(LayerTests.allTests), ] } #endif From c8b7e730fa1b57b6be9b69e5ba6ccd0bb051118a Mon Sep 17 00:00:00 2001 From: Dan Zheng Date: Tue, 19 Mar 2019 18:16:49 -0400 Subject: [PATCH 4/7] Update Sources/DeepLearning/Layer.swift Co-Authored-By: dave-fernandes --- Sources/DeepLearning/Layer.swift | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Sources/DeepLearning/Layer.swift b/Sources/DeepLearning/Layer.swift index 96934920e..31fa5b3da 100644 --- a/Sources/DeepLearning/Layer.swift +++ b/Sources/DeepLearning/Layer.swift @@ -357,7 +357,7 @@ public struct Conv1D: Layer { /// - Parameters: /// - filter: The filter (width, inputChannels, outputChannels). /// - bias: The bias (dimensions: output channels). - /// - activation: The activation activation. + /// - activation: The element-wise activation function. /// - stride: The stride. /// - padding: The padding. public init( From 827a1852f40bbbc5660a27f58f3cd397a4e7e4a6 Mon Sep 17 00:00:00 2001 From: Dave Fernandes Date: Tue, 19 Mar 2019 21:11:08 -0400 Subject: [PATCH 5/7] Comments and spacing issue fixes --- Sources/DeepLearning/Layer.swift | 80 ++++++++++++++++---------------- 1 file changed, 41 insertions(+), 39 deletions(-) diff --git a/Sources/DeepLearning/Layer.swift b/Sources/DeepLearning/Layer.swift index 31fa5b3da..8e1e6f674 100644 --- a/Sources/DeepLearning/Layer.swift +++ b/Sources/DeepLearning/Layer.swift @@ -338,9 +338,9 @@ public extension Dense { /// tensor of outputs. @_fixed_layout public struct Conv1D: Layer { - /// The 3-D convolution kernel. + /// The 3-D convolution kernel `[width, inputChannels, outputChannels]`. public var filter: Tensor - /// The bias vector. + /// The bias vector `[outputChannels]`. public var bias: Tensor /// An activation function. public typealias Activation = @differentiable (Tensor) -> Tensor @@ -355,11 +355,11 @@ public struct Conv1D: Layer { /// padding. /// /// - Parameters: - /// - filter: The filter (width, inputChannels, outputChannels). - /// - bias: The bias (dimensions: output channels). + /// - filter: The 3-D convolution kernel `[width, inputChannels, outputChannels]`. + /// - bias: The bias vector `[outputChannels]`. /// - activation: The element-wise activation function. - /// - stride: The stride. - /// - padding: The padding. + /// - stride: The stride of the sliding window for temporal dimension. + /// - padding: The padding algorithm for convolution. public init( filter: Tensor, bias: Tensor, @@ -377,14 +377,14 @@ public struct Conv1D: Layer { /// Returns the output obtained from applying the layer to the given input. /// /// - Parameters: - /// - input: The input to the layer (batchCount, width, inputChannels). + /// - input: The input to the layer `[batchCount, width, inputChannels]`. /// - context: The contextual information for the layer application, e.g. the current learning /// phase. - /// - Returns: The output (batchCount, newWidth, outputChannels). + /// - Returns: The output `[batchCount, newWidth, outputChannels]`. @differentiable public func applied(to input: Tensor, in _: Context) -> Tensor { - let conv2D = input.expandingShape(at: 1).convolved2D(withFilter: filter.expandingShape(at: 0), - strides: (1, 1, stride, 1), padding: padding) + let conv2D = input.expandingShape(at: 1).convolved2D( + withFilter: filter.expandingShape(at: 0), strides: (1, 1, stride, 1), padding: padding) return activation(conv2D.squeezingShape(at: 1) + bias) } } @@ -395,10 +395,11 @@ public extension Conv1D where Scalar.RawSignificand: FixedWidthInteger { /// initialization with the specified generator. The bias vector is initialized with zeros. /// /// - Parameters: - /// - filterShape: The shape of the filter (width, inputChannels, outputChannels). - /// - stride: The stride. - /// - padding: The padding. - /// - activation: The activation function. + /// - filterShape: The 3-D shape of the filter, representing + /// `[width, inputChannels, outputChannels]`. + /// - stride: The stride of the sliding window for temporal dimension. + /// - padding: The padding algorithm for convolution. + /// - activation: The element-wise activation function. /// - generator: The random number generator for initialization. /// /// - Note: Use `init(filterShape:stride:padding:activation:seed:)` for faster random @@ -427,10 +428,11 @@ public extension Conv1D { /// initialization with the specified seed. The bias vector is initialized with zeros. /// /// - Parameters: - /// - filterShape: The shape of the filter (width, inputChannels, outputChannels). - /// - stride: The stride. - /// - padding: The padding. - /// - activation: The activation function. + /// - filterShape: The 3-D shape of the filter, representing + /// `[width, inputChannels, outputChannels]`. + /// - stride: The stride of the sliding window for temporal dimension. + /// - padding: The padding algorithm for convolution. + /// - activation: The element-wise activation function. /// - seed: The random seed for initialization. The default value is random. init( filterShape: (Int, Int, Int), @@ -474,11 +476,11 @@ public struct Conv2D: Layer { /// padding. /// /// - Parameters: - /// - filter: The filter. - /// - bias: The bias. - /// - activation: The activation activation. - /// - strides: The strides. - /// - padding: The padding. + /// - filter: The 4-D convolution kernel. + /// - bias: The bias vector. + /// - activation: The element-wise activation function. + /// - strides: The strides of the sliding window for spatial dimensions. + /// - padding: The padding algorithm for convolution. public init( filter: Tensor, bias: Tensor, @@ -514,10 +516,10 @@ public extension Conv2D { /// initialization with the specified generator. The bias vector is initialized with zeros. /// /// - Parameters: - /// - filterShape: The shape of the filter, represented by a tuple of `4` integers. - /// - strides: The strides. - /// - padding: The padding. - /// - activation: The activation function. + /// - filterShape: The shape of the 4-D convolution kernel. + /// - strides: The strides of the sliding window for spatial dimensions. + /// - padding: The padding algorithm for convolution. + /// - activation: The element-wise activation function. /// - generator: The random number generator for initialization. /// /// - Note: Use `init(filterShape:strides:padding:activation:seed:)` for faster random @@ -547,10 +549,10 @@ public extension Conv2D { /// initialization with the specified seed. The bias vector is initialized with zeros. /// /// - Parameters: - /// - filterShape: The shape of the filter, represented by a tuple of `4` integers. - /// - strides: The strides. - /// - padding: The padding. - /// - activation: The activation function. + /// - filterShape: The shape of the 4-D convolution kernel. + /// - strides: The strides of the sliding window for spatial dimensions. + /// - padding: The padding algorithm for convolution. + /// - activation: The element-wise activation function. /// - seed: The random seed for initialization. The default value is random. init( filterShape: (Int, Int, Int, Int), @@ -710,9 +712,9 @@ public struct MaxPool1D: Layer { /// Creates a max pooling layer. /// /// - Parameters: - /// - poolSize: Factor by which to downscale. - /// - stride: The stride. - /// - padding: The padding. + /// - poolSize: The size of the sliding reduction window for pooling. + /// - stride: The stride of the sliding window for temporal dimension. + /// - padding: The padding algorithm for pooling. public init( poolSize: Int, stride: Int, @@ -733,7 +735,7 @@ public struct MaxPool1D: Layer { @differentiable public func applied(to input: Tensor, in _: Context) -> Tensor { return input.expandingShape(at: 1).maxPooled( - kernelSize: (1, 1, poolSize, 1), strides: (1, 1, stride, 1), padding: padding + kernelSize: (1, 1, poolSize, 1), strides: (1, 1, stride, 1), padding: padding ).squeezingShape(at: 1) } } @@ -801,9 +803,9 @@ public struct AvgPool1D: Layer { /// Creates an average pooling layer. /// /// - Parameters: - /// - poolSize: Factor by which to downscale. - /// - stride: The stride. - /// - padding: The padding. + /// - poolSize: The size of the sliding reduction window for pooling. + /// - stride: The stride of the sliding window for temporal dimension. + /// - padding: The padding algorithm for pooling. public init( poolSize: Int, stride: Int, @@ -824,7 +826,7 @@ public struct AvgPool1D: Layer { @differentiable public func applied(to input: Tensor, in _: Context) -> Tensor { return input.expandingShape(at: 1).averagePooled( - kernelSize: (1, 1, poolSize, 1), strides: (1, 1, stride, 1), padding: padding + kernelSize: (1, 1, poolSize, 1), strides: (1, 1, stride, 1), padding: padding ).squeezingShape(at: 1) } } From deb771b111b7f527d6f4cd5965124344bb41142e Mon Sep 17 00:00:00 2001 From: Dave Fernandes Date: Tue, 19 Mar 2019 21:16:01 -0400 Subject: [PATCH 6/7] more indentation fixes --- Sources/DeepLearning/Layer.swift | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/Sources/DeepLearning/Layer.swift b/Sources/DeepLearning/Layer.swift index 8e1e6f674..6c2688d1b 100644 --- a/Sources/DeepLearning/Layer.swift +++ b/Sources/DeepLearning/Layer.swift @@ -445,12 +445,12 @@ public extension Conv1D { let filterTensorShape = TensorShape([ Int32(filterShape.0), Int32(filterShape.1), Int32(filterShape.2)]) self.init( - filter: Tensor(glorotUniform: filterTensorShape, seed: seed), - bias: Tensor(zeros: TensorShape([Int32(filterShape.2)])), - activation: activation, - stride: Int32(stride), - padding: padding) - } + filter: Tensor(glorotUniform: filterTensorShape, seed: seed), + bias: Tensor(zeros: TensorShape([Int32(filterShape.2)])), + activation: activation, + stride: Int32(stride), + padding: padding) + } } /// A 2-D convolution layer (e.g. spatial convolution over images). @@ -566,11 +566,11 @@ public extension Conv2D { Int32(filterShape.0), Int32(filterShape.1), Int32(filterShape.2), Int32(filterShape.3)]) self.init( - filter: Tensor(glorotUniform: filterTensorShape, seed: seed), - bias: Tensor(zeros: TensorShape([Int32(filterShape.3)])), - activation: activation, - strides: (Int32(strides.0), Int32(strides.1)), - padding: padding) + filter: Tensor(glorotUniform: filterTensorShape, seed: seed), + bias: Tensor(zeros: TensorShape([Int32(filterShape.3)])), + activation: activation, + strides: (Int32(strides.0), Int32(strides.1)), + padding: padding) } } @@ -786,7 +786,7 @@ public struct MaxPool2D: Layer { @differentiable public func applied(to input: Tensor, in _: Context) -> Tensor { return input.maxPooled( - kernelSize: poolSize, strides: strides, padding: padding) + kernelSize: poolSize, strides: strides, padding: padding) } } From a76891dac9f8e312f4d19ec6263ae2e0b4c9b391 Mon Sep 17 00:00:00 2001 From: Richard Wei Date: Wed, 20 Mar 2019 00:25:19 -0400 Subject: [PATCH 7/7] Update Sources/DeepLearning/Layer.swift Co-Authored-By: dave-fernandes --- Sources/DeepLearning/Layer.swift | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Sources/DeepLearning/Layer.swift b/Sources/DeepLearning/Layer.swift index 6c2688d1b..a71189183 100644 --- a/Sources/DeepLearning/Layer.swift +++ b/Sources/DeepLearning/Layer.swift @@ -450,7 +450,7 @@ public extension Conv1D { activation: activation, stride: Int32(stride), padding: padding) - } + } } /// A 2-D convolution layer (e.g. spatial convolution over images).