// Unified diff covering two files:
//   * Models/ImageClassification/SqueezeNet.swift
//   * Tests/ImageClassificationTests/Inference.swift
//
// Model changes:
//   * `SqueezeNet` is renamed to `SqueezeNetV1_0`. No alias for the old name
//     is visible in these hunks, so existing callers would need updating.
//   * In `SqueezeNetV1_0`, `conv1` and `conv10` now pass `activation: relu`.
//   * A new `SqueezeNetV1_1` layer is added. It uses a 3x3 `conv1` with 64
//     output filters and stride 2, max pools after `conv1`, `fire3` and
//     `fire5`, and then runs `fire6`...`fire9`, `dropout`, `conv10` and a
//     13x13 average pool. The result is reshaped to
//     `[input.shape[0], conv10.filter.shape[3]]`, i.e. one row per input
//     sample and one column per `classCount` output filter.
//
// Test changes:
//   * `testSqueezeNet` is renamed to `testSqueezeNetV1_0`, and
//     `testSqueezeNetV1_1` is added. Each feeds a `[1, 224, 224, 3]` random
//     tensor and asserts that the result shape is `[1, 1000]`.
//   * Both test names are listed in the name/function tuple list in the
//     `ImageClassificationInferenceTests` extension.
//
// NOTE(review): the line breaks below were restored so that every hunk agrees
// with the old/new line counts in its `@@` header. The 4-space indentation is
// an assumption -- confirm against the target files before applying.
// NOTE(review): `Conv2D`, `MaxPool2D`, `AvgPool2D`, `Dropout` and `Tensor`
// appear without generic arguments; these may have been lost when the patch
// was extracted -- TODO confirm against the upstream sources.
diff --git a/Models/ImageClassification/SqueezeNet.swift b/Models/ImageClassification/SqueezeNet.swift
index b1c17f08ad0..00f3dcecff5 100644
--- a/Models/ImageClassification/SqueezeNet.swift
+++ b/Models/ImageClassification/SqueezeNet.swift
@@ -52,8 +52,12 @@ public struct Fire: Layer {
     }
 }
 
-public struct SqueezeNet: Layer {
-    public var conv1 = Conv2D(filterShape: (7, 7, 3, 96), strides: (2, 2), padding: .same)
+public struct SqueezeNetV1_0: Layer {
+    public var conv1 = Conv2D(
+        filterShape: (7, 7, 3, 96),
+        strides: (2, 2),
+        padding: .same,
+        activation: relu)
     public var maxPool1 = MaxPool2D(poolSize: (3, 3), strides: (2, 2))
     public var fire2 = Fire(
         inputFilterCount: 96,
@@ -102,7 +106,7 @@ public struct SqueezeNet: Layer {
     public var dropout = Dropout(probability: 0.5)
 
     public init(classCount: Int) {
-        conv10 = Conv2D(filterShape: (1, 1, 512, classCount), strides: (1, 1))
+        conv10 = Conv2D(filterShape: (1, 1, 512, classCount), strides: (1, 1), activation: relu)
     }
 
     @differentiable
@@ -115,3 +119,71 @@ public struct SqueezeNet: Layer {
         return convolved2
     }
 }
+
+public struct SqueezeNetV1_1: Layer {
+    public var conv1 = Conv2D(
+        filterShape: (3, 3, 3, 64),
+        strides: (2, 2),
+        padding: .same,
+        activation: relu)
+    public var maxPool1 = MaxPool2D(poolSize: (3, 3), strides: (2, 2))
+    public var fire2 = Fire(
+        inputFilterCount: 64,
+        squeezeFilterCount: 16,
+        expand1FilterCount: 64,
+        expand3FilterCount: 64)
+    public var fire3 = Fire(
+        inputFilterCount: 128,
+        squeezeFilterCount: 16,
+        expand1FilterCount: 64,
+        expand3FilterCount: 64)
+    public var maxPool3 = MaxPool2D(poolSize: (3, 3), strides: (2, 2))
+    public var fire4 = Fire(
+        inputFilterCount: 128,
+        squeezeFilterCount: 32,
+        expand1FilterCount: 128,
+        expand3FilterCount: 128)
+    public var fire5 = Fire(
+        inputFilterCount: 256,
+        squeezeFilterCount: 32,
+        expand1FilterCount: 128,
+        expand3FilterCount: 128)
+    public var maxPool5 = MaxPool2D(poolSize: (3, 3), strides: (2, 2))
+    public var fire6 = Fire(
+        inputFilterCount: 256,
+        squeezeFilterCount: 48,
+        expand1FilterCount: 192,
+        expand3FilterCount: 192)
+    public var fire7 = Fire(
+        inputFilterCount: 384,
+        squeezeFilterCount: 48,
+        expand1FilterCount: 192,
+        expand3FilterCount: 192)
+    public var fire8 = Fire(
+        inputFilterCount: 384,
+        squeezeFilterCount: 64,
+        expand1FilterCount: 256,
+        expand3FilterCount: 256)
+    public var fire9 = Fire(
+        inputFilterCount: 512,
+        squeezeFilterCount: 64,
+        expand1FilterCount: 256,
+        expand3FilterCount: 256)
+    public var conv10: Conv2D
+    public var avgPool10 = AvgPool2D(poolSize: (13, 13), strides: (1, 1))
+    public var dropout = Dropout(probability: 0.5)
+
+    public init(classCount: Int) {
+        conv10 = Conv2D(filterShape: (1, 1, 512, classCount), strides: (1, 1), activation: relu)
+    }
+
+    @differentiable
+    public func callAsFunction(_ input: Tensor) -> Tensor {
+        let convolved1 = input.sequenced(through: conv1, maxPool1)
+        let fired1 = convolved1.sequenced(through: fire2, fire3, maxPool3, fire4, fire5)
+        let fired2 = fired1.sequenced(through: maxPool5, fire6, fire7, fire8, fire9)
+        let convolved2 = fired2.sequenced(through: dropout, conv10, avgPool10)
+            .reshaped(to: [input.shape[0], conv10.filter.shape[3]])
+        return convolved2
+    }
+}
diff --git a/Tests/ImageClassificationTests/Inference.swift b/Tests/ImageClassificationTests/Inference.swift
index fb1e173f5b8..710cc691ac6 100644
--- a/Tests/ImageClassificationTests/Inference.swift
+++ b/Tests/ImageClassificationTests/Inference.swift
@@ -92,11 +92,20 @@ final class ImageClassificationInferenceTests: XCTestCase {
         XCTAssertEqual(resNet34ImageNetResult.shape, [1, 1000])
     }
 
-    func testSqueezeNet() {
+    func testSqueezeNetV1_0() {
         let input = Tensor(
             randomNormal: [1, 224, 224, 3], mean: Tensor(0.5),
             standardDeviation: Tensor(0.1), seed: (0xffeffe, 0xfffe))
-        let squeezeNet = SqueezeNet(classCount: 1000)
+        let squeezeNet = SqueezeNetV1_0(classCount: 1000)
+        let squeezeNetResult = squeezeNet(input)
+        XCTAssertEqual(squeezeNetResult.shape, [1, 1000])
+    }
+
+    func testSqueezeNetV1_1() {
+        let input = Tensor(
+            randomNormal: [1, 224, 224, 3], mean: Tensor(0.5),
+            standardDeviation: Tensor(0.1), seed: (0xffeffe, 0xfffe))
+        let squeezeNet = SqueezeNetV1_1(classCount: 1000)
         let squeezeNetResult = squeezeNet(input)
         XCTAssertEqual(squeezeNetResult.shape, [1, 1000])
     }
@@ -152,7 +161,8 @@ extension ImageClassificationInferenceTests {
         ("testLeNet", testLeNet),
         ("testResNet", testResNet),
         ("testResNetV2", testResNetV2),
-        ("testSqueezeNet", testSqueezeNet),
+        ("testSqueezeNetV1_0", testSqueezeNetV1_0),
+        ("testSqueezeNetV1_1", testSqueezeNetV1_1),
         ("testWideResNet", testWideResNet),
     ]
 }