diff --git a/Models/ImageClassification/WideResNet.swift b/Models/ImageClassification/WideResNet.swift
index 79db5cb3753..2852eb7c33b 100644
--- a/Models/ImageClassification/WideResNet.swift
+++ b/Models/ImageClassification/WideResNet.swift
@@ -25,79 +25,67 @@ public struct BatchNormConv2DBlock: Layer {
     public var conv1: Conv2D<Float>
     public var norm2: BatchNorm<Float>
     public var conv2: Conv2D<Float>
+    public var shortcut: Conv2D<Float>
+    let isExpansion: Bool
+    let dropout: Dropout<Float> = Dropout(probability: 0.3)
 
     public init(
-        filterShape: (Int, Int, Int, Int),
+        featureCounts: (Int, Int),
+        kernelSize: Int = 3,
         strides: (Int, Int) = (1, 1),
         padding: Padding = .same
     ) {
-        self.norm1 = BatchNorm(featureCount: filterShape.2)
-        self.conv1 = Conv2D(filterShape: filterShape, strides: strides, padding: padding)
-        self.norm2 = BatchNorm(featureCount: filterShape.3)
-        self.conv2 = Conv2D(filterShape: filterShape, strides: (1, 1), padding: padding)
+        self.norm1 = BatchNorm(featureCount: featureCounts.0)
+        self.conv1 = Conv2D(
+            filterShape: (kernelSize, kernelSize, featureCounts.0, featureCounts.1),
+            strides: strides,
+            padding: padding)
+        self.norm2 = BatchNorm(featureCount: featureCounts.1)
+        self.conv2 = Conv2D(filterShape: (kernelSize, kernelSize, featureCounts.1, featureCounts.1),
+            strides: (1, 1),
+            padding: padding)
+        self.shortcut = Conv2D(filterShape: (1, 1, featureCounts.0, featureCounts.1),
+            strides: strides,
+            padding: padding)
+        self.isExpansion = featureCounts.1 != featureCounts.0 || strides != (1, 1)
     }
 
     @differentiable
     public func callAsFunction(_ input: Tensor<Float>) -> Tensor<Float> {
-        let firstLayer = conv1(relu(norm1(input)))
-        return conv2(relu(norm2(firstLayer)))
+        let preact1 = relu(norm1(input))
+        var residual = conv1(preact1)
+        let preact2: Tensor<Float>
+        let shortcutResult: Tensor<Float>
+        if isExpansion {
+            shortcutResult = shortcut(preact1)
+            preact2 = relu(norm2(residual))
+        } else {
+            shortcutResult = input
+            preact2 = dropout(relu(norm2(residual)))
+        }
+        residual = conv2(preact2)
+        return residual + shortcutResult
     }
 }
 
 public struct WideResNetBasicBlock: Layer {
     public var blocks: [BatchNormConv2DBlock]
-    public var shortcut: Conv2D<Float>
 
     public init(
         featureCounts: (Int, Int),
         kernelSize: Int = 3,
         depthFactor: Int = 2,
-        widenFactor: Int = 1,
         initialStride: (Int, Int) = (2, 2)
     ) {
-        if initialStride == (1, 1) {
-            self.blocks = [
-                BatchNormConv2DBlock(
-                    filterShape: (
-                        kernelSize, kernelSize,
-                        featureCounts.0, featureCounts.1 * widenFactor
-                    ),
-                    strides: initialStride)
-            ]
-            self.shortcut = Conv2D(
-                filterShape: (1, 1, featureCounts.0, featureCounts.1 * widenFactor),
-                strides: initialStride)
-        } else {
-            self.blocks = [
-                BatchNormConv2DBlock(
-                    filterShape: (
-                        kernelSize, kernelSize,
-                        featureCounts.0 * widenFactor, featureCounts.1 * widenFactor
-                    ),
-                    strides: initialStride)
-            ]
-            self.shortcut = Conv2D(
-                filterShape: (1, 1, featureCounts.0 * widenFactor, featureCounts.1 * widenFactor),
-                strides: initialStride)
-        }
+        self.blocks = [BatchNormConv2DBlock(featureCounts: featureCounts, strides: initialStride)]
         for _ in 1..<depthFactor {
-            self.blocks += [
-                BatchNormConv2DBlock(
-                    filterShape: (
-                        kernelSize, kernelSize,
-                        featureCounts.1 * widenFactor, featureCounts.1 * widenFactor
-                    )
-                )
-            ]
+            self.blocks += [BatchNormConv2DBlock(featureCounts: (featureCounts.1, featureCounts.1))]
         }
     }
 
     @differentiable
     public func callAsFunction(_ input: Tensor<Float>) -> Tensor<Float> {
-        let blocksReduced = blocks.differentiableReduce(input) { last, layer in
-            relu(layer(last))
-        }
-        return relu(blocksReduced + shortcut(input))
+        return blocks.differentiableReduce(input) { $1($0) }
     }
 }
 
@@ -116,15 +104,12 @@ public struct WideResNet: Layer {
 
     public init(depthFactor: Int = 2, widenFactor: Int = 8) {
         self.l1 = Conv2D(filterShape: (3, 3, 3, 16), strides: (1, 1), padding: .same)
-        l2 = WideResNetBasicBlock(
-            featureCounts: (16, 16), depthFactor: depthFactor,
-            widenFactor: widenFactor, initialStride: (1, 1))
-        l3 = WideResNetBasicBlock(
-            featureCounts: (16, 32), depthFactor: depthFactor,
-            widenFactor: widenFactor)
-        l4 = WideResNetBasicBlock(
-            featureCounts: (32, 64), depthFactor: depthFactor,
-            widenFactor: widenFactor)
+        self.l2 = WideResNetBasicBlock(
+            featureCounts: (16, 16 * widenFactor), depthFactor: depthFactor, initialStride: (1, 1))
+        self.l3 = WideResNetBasicBlock(featureCounts: (16 * widenFactor, 32 * widenFactor),
+            depthFactor: depthFactor)
+        self.l4 = WideResNetBasicBlock(featureCounts: (32 * widenFactor, 64 * widenFactor),
+            depthFactor: depthFactor)
         self.norm = BatchNorm(featureCount: 64 * widenFactor)
         self.avgPool = AvgPool2D(poolSize: (8, 8), strides: (8, 8))