diff --git a/CIFAR/WideResNet.swift b/CIFAR/WideResNet.swift
index a58b0529714..12678a47d87 100644
--- a/CIFAR/WideResNet.swift
+++ b/CIFAR/WideResNet.swift
@@ -20,14 +20,46 @@ import TensorFlow
 // https://arxiv.org/abs/1605.07146
 // https://github.com/szagoruyko/wide-residual-networks
 
-struct BatchNormConv2DBlock: Layer {
+struct IdentityLayer: Layer {
     typealias Input = Tensor<Float>
     typealias Output = Tensor<Float>
 
     var norm1: BatchNorm<Float>
     var conv1: Conv2D<Float>
     var norm2: BatchNorm<Float>
+    let dropout = Dropout<Float>(probability: 0.3)
     var conv2: Conv2D<Float>
+
+
+    init(
+        filterShape: (Int, Int, Int, Int),
+        padding: Padding = .same
+    ) {
+        self.norm1 = BatchNorm(featureCount: filterShape.3)
+        self.conv1 = Conv2D(filterShape: filterShape, strides: (1, 1), padding: padding)
+        self.norm2 = BatchNorm(featureCount: filterShape.3)
+        self.conv2 = Conv2D(filterShape: filterShape, strides: (1, 1), padding: padding)
+    }
+
+    @differentiable
+    func call(_ input: Input) -> Output {
+        let preactivation1 = relu(norm1(input))
+        let firstLayer = conv1(preactivation1)
+        let preactivation2 = dropout(relu(norm2(firstLayer)))
+        return conv2(preactivation2) + input
+    }
+}
+
+struct ExpansionLayer: Layer {
+    typealias Input = Tensor<Float>
+    typealias Output = Tensor<Float>
+
+    var norm1: BatchNorm<Float>
+    var conv1: Conv2D<Float>
+    var norm2: BatchNorm<Float>
+    let dropout = Dropout<Float>(probability: 0.3)
+    var conv2: Conv2D<Float>
+    var shortcut: Conv2D<Float>
 
     init(
         filterShape: (Int, Int, Int, Int),
@@ -37,61 +69,53 @@ struct BatchNormConv2DBlock: Layer {
         self.norm1 = BatchNorm(featureCount: filterShape.2)
         self.conv1 = Conv2D(filterShape: filterShape, strides: strides, padding: padding)
         self.norm2 = BatchNorm(featureCount: filterShape.3)
-        self.conv2 = Conv2D(filterShape: filterShape, strides: (1, 1), padding: padding)
+        self.conv2 = Conv2D(filterShape: (filterShape.0, filterShape.1,
+                                          filterShape.3, filterShape.3),
+                            strides: (1, 1), padding: padding)
+        self.shortcut = Conv2D(filterShape: (1, 1, filterShape.2, filterShape.3),
+                               strides: strides, padding: padding)
     }
 
     @differentiable
     func call(_ input: Input) -> Output {
-        let firstLayer = conv1(relu(norm1(input)))
-        return conv2(relu(norm2(firstLayer)))
+        let preactivation1 = relu(norm1(input))
+        let firstLayer = conv1(preactivation1)
+        let preactivation2 = dropout(relu(norm2(firstLayer)))
+        return conv2(preactivation2) + shortcut(preactivation1)
    }
 }
 
 struct WideResNetBasicBlock: Layer {
     typealias Input = Tensor<Float>
     typealias Output = Tensor<Float>
-
-    var blocks: [BatchNormConv2DBlock]
-    var shortcut: Conv2D<Float>
+    var expansion: ExpansionLayer
+    var blocks: [IdentityLayer]
 
     init(
         featureCounts: (Int, Int),
         kernelSize: Int = 3,
         depthFactor: Int = 2,
-        widenFactor: Int = 1,
         initialStride: (Int, Int) = (2, 2)
     ) {
-        if initialStride == (1, 1) {
-            self.blocks = [BatchNormConv2DBlock(
-                filterShape: (kernelSize, kernelSize,
-                              featureCounts.0, featureCounts.1 * widenFactor),
-                strides: initialStride)]
-            self.shortcut = Conv2D(
-                filterShape: (1, 1, featureCounts.0, featureCounts.1 * widenFactor),
-                strides: initialStride)
-        } else {
-            self.blocks = [BatchNormConv2DBlock(
-                filterShape: (kernelSize, kernelSize,
-                              featureCounts.0 * widenFactor, featureCounts.1 * widenFactor),
-                strides: initialStride)]
-            self.shortcut = Conv2D(
-                filterShape: (1, 1, featureCounts.0 * widenFactor, featureCounts.1 * widenFactor),
-                strides: initialStride)
-        }
+        self.expansion = ExpansionLayer(
+            filterShape: (kernelSize, kernelSize,
+                          featureCounts.0, featureCounts.1),
+            strides: initialStride)
+        self.blocks = []
         for _ in 1..<depthFactor {
-            self.blocks += [BatchNormConv2DBlock(
-                filterShape: (kernelSize, kernelSize,
-                              featureCounts.1 * widenFactor, featureCounts.1 * widenFactor))]
+            self.blocks += [IdentityLayer(
+                filterShape: (kernelSize, kernelSize,
+                              featureCounts.1, featureCounts.1))]
         }
     }
 
     @differentiable
     func call(_ input: Input) -> Output {
-        let blocksReduced = blocks.differentiableReduce(input) { last, layer in
-            relu(layer(last))
+        var net = expansion(input)
+        net = blocks.differentiableReduce(net) { last, layer in
+            layer(last)
         }
-        return relu(blocksReduced + shortcut(input))
+        return net
     }
 }
@@ -111,18 +135,23 @@ struct WideResNet: Layer {
     var classifier: Dense<Float>
 
     init(depthFactor: Int = 2, widenFactor: Int = 8) {
-        self.l1 = Conv2D(filterShape: (3, 3, 3, 16), strides: (1, 1), padding: .same)
-
-        l2 = WideResNetBasicBlock(featureCounts: (16, 16), depthFactor: depthFactor,
-                                  widenFactor: widenFactor, initialStride: (1, 1))
-        l3 = WideResNetBasicBlock(featureCounts: (16, 32), depthFactor: depthFactor,
-                                  widenFactor: widenFactor)
-        l4 = WideResNetBasicBlock(featureCounts: (32, 64), depthFactor: depthFactor,
-                                  widenFactor: widenFactor)
+        let featureCount1 = 16
+        let featureCount2 = 16 * widenFactor
+        let featureCount3 = 32 * widenFactor
+        let featureCount4 = 64 * widenFactor
+        self.l1 = Conv2D(filterShape: (3, 3, 3, featureCount1), strides: (1, 1), padding: .same)
+
+        l2 = WideResNetBasicBlock(featureCounts: (featureCount1, featureCount2),
+                                  depthFactor: depthFactor,
+                                  initialStride: (1, 1))
+        l3 = WideResNetBasicBlock(featureCounts: (featureCount2, featureCount3),
+                                  depthFactor: depthFactor)
+        l4 = WideResNetBasicBlock(featureCounts: (featureCount3, featureCount4),
+                                  depthFactor: depthFactor)
 
-        self.norm = BatchNorm(featureCount: 64 * widenFactor)
+        self.norm = BatchNorm(featureCount: featureCount4)
         self.avgPool = AvgPool2D(poolSize: (8, 8), strides: (8, 8))
-        self.classifier = Dense(inputSize: 64 * widenFactor, outputSize: 10)
+        self.classifier = Dense(inputSize: featureCount4, outputSize: 10)
     }
 
     @differentiable
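For reviewers who want to sanity-check the reworked layers, here is a minimal smoke-test sketch; it is not part of the patch. It assumes a Swift for TensorFlow toolchain of this era, where `Layer` values are callable as functions, and CIFAR-10 style inputs of shape `[batchSize, 32, 32, 3]`; the batch size of 4 and the name `dummyBatch` are illustrative only.

```swift
import TensorFlow

// Hypothetical usage, not part of this diff: build the default WRN
// configuration and run one forward pass on a dummy CIFAR-10 batch.
let model = WideResNet(depthFactor: 2, widenFactor: 8)

// Four random 32x32 RGB images standing in for CIFAR-10 data.
let dummyBatch = Tensor<Float>(randomNormal: [4, 32, 32, 3])

// l2 keeps 32x32 (initialStride (1, 1)), l3 halves to 16x16, l4 to 8x8;
// the 8x8 average pool then leaves one spatial position per feature map,
// so the classifier sees 64 * widenFactor features and emits 10 logits.
let logits = model(dummyBatch)
print(logits.shape) // expected: [4, 10]
```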