67 changes: 36 additions & 31 deletions ResNet/ResNet.swift
@@ -20,6 +20,11 @@ import TensorFlow
// https://arxiv.org/abs/1512.03385
// using shortcut layer to connect BasicBlock layers (aka Option (B))

enum InputType {
    case cifar
    case imagenet
}

struct ConvBN: Layer {
    typealias Input = Tensor<Float>
    typealias Output = Tensor<Float>
@@ -219,16 +224,18 @@ struct ResNetBasic: Layer {
    var flatten = Flatten<Float>()
    var classifier: Dense<Float>

    init(imageSize: Int, classCount: Int, layerBlockCounts: (Int, Int, Int, Int)) {
        // default to the ImageNet case where imageSize == 224
        // Swift requires that all properties get initialized outside control flow
        l1 = ConvBN(filterShape: (7, 7, 3, 64), strides: (2, 2), padding: .same)
        maxPool = MaxPool2D(poolSize: (3, 3), strides: (2, 2))
        avgPool = AvgPool2D(poolSize: (7, 7), strides: (7, 7))
        if imageSize == 32 {
            l1 = ConvBN(filterShape: (3, 3, 3, 64), padding: .same)
            maxPool = MaxPool2D(poolSize: (1, 1), strides: (1, 1)) // no-op
            avgPool = AvgPool2D(poolSize: (4, 4), strides: (4, 4))
    init(input: InputType, layerBlockCounts: (Int, Int, Int, Int)) {
        switch input {
        case .imagenet:
            l1 = ConvBN(filterShape: (7, 7, 3, 64), strides: (2, 2), padding: .same)
            maxPool = MaxPool2D(poolSize: (3, 3), strides: (2, 2))
            avgPool = AvgPool2D(poolSize: (7, 7), strides: (7, 7))
            classifier = Dense(inputSize: 512, outputSize: 1000)
        case .cifar:
            l1 = ConvBN(filterShape: (3, 3, 3, 64), padding: .same)
            maxPool = MaxPool2D(poolSize: (1, 1), strides: (1, 1)) // no-op
            avgPool = AvgPool2D(poolSize: (4, 4), strides: (4, 4))
            classifier = Dense(inputSize: 512, outputSize: 10)
        }

        l2b = ResidualBasicBlockStack(featureCounts: (64, 64, 64, 64),
@@ -239,8 +246,6 @@
                                      blockCount: layerBlockCounts.2)
        l5b = ResidualBasicBlockStack(featureCounts: (512, 512, 512, 512),
                                      blockCount: layerBlockCounts.3)

        classifier = Dense(inputSize: 512, outputSize: classCount)
    }

    @differentiable
@@ -260,12 +265,12 @@ extension ResNetBasic {
        case resNet34
    }

    init(kind: Kind, imageSize: Int, classCount: Int) {
    init(kind: Kind, type: InputType) {
        switch kind {
        case .resNet18:
            self.init(imageSize: imageSize, classCount: classCount, layerBlockCounts: (2, 2, 2, 2))
            self.init(input: type, layerBlockCounts: (2, 2, 2, 2))
        case .resNet34:
            self.init(imageSize: imageSize, classCount: classCount, layerBlockCounts: (3, 4, 6, 3))
            self.init(input: type, layerBlockCounts: (3, 4, 6, 3))
        }
    }
}
@@ -293,16 +298,18 @@ struct ResNet: Layer {
    var flatten = Flatten<Float>()
    var classifier: Dense<Float>

    init(imageSize: Int, classCount: Int, layerBlockCounts: (Int, Int, Int, Int)) {
        // default to the ImageNet case where imageSize == 224
        // Swift requires that all properties get initialized outside control flow
        l1 = ConvBN(filterShape: (7, 7, 3, 64), strides: (2, 2), padding: .same)
        maxPool = MaxPool2D(poolSize: (3, 3), strides: (2, 2))
        avgPool = AvgPool2D(poolSize: (7, 7), strides: (7, 7))
        if imageSize == 32 {
            l1 = ConvBN(filterShape: (3, 3, 3, 64), padding: .same)
            maxPool = MaxPool2D(poolSize: (1, 1), strides: (1, 1)) // no-op
            avgPool = AvgPool2D(poolSize: (4, 4), strides: (4, 4))
    init(input: InputType, layerBlockCounts: (Int, Int, Int, Int)) {
        switch input {
        case .imagenet:
            l1 = ConvBN(filterShape: (7, 7, 3, 64), strides: (2, 2), padding: .same)
            maxPool = MaxPool2D(poolSize: (3, 3), strides: (2, 2))
            avgPool = AvgPool2D(poolSize: (7, 7), strides: (7, 7))
            classifier = Dense(inputSize: 2048, outputSize: 1000)
        case .cifar:
            l1 = ConvBN(filterShape: (3, 3, 3, 64), padding: .same)
            maxPool = MaxPool2D(poolSize: (1, 1), strides: (1, 1)) // no-op
            avgPool = AvgPool2D(poolSize: (4, 4), strides: (4, 4))
            classifier = Dense(inputSize: 2048, outputSize: 10)
        }

        l2b = ResidualIdentityBlockStack(featureCounts: (256, 64, 64, 256),
@@ -313,8 +320,6 @@
                                         blockCount: layerBlockCounts.2)
        l5b = ResidualIdentityBlockStack(featureCounts: (2048, 512, 512, 2048),
                                         blockCount: layerBlockCounts.3)

        classifier = Dense(inputSize: 2048, outputSize: classCount)
    }

    @differentiable
@@ -335,14 +340,14 @@ extension ResNet {
        case resNet152
    }

    init(kind: Kind, imageSize: Int, classCount: Int) {
    init(kind: Kind, type: InputType) {
        switch kind {
        case .resNet50:
            self.init(imageSize: imageSize, classCount: classCount, layerBlockCounts: (3, 4, 6, 3))
            self.init(input: type, layerBlockCounts: (3, 4, 6, 3))
        case .resNet101:
            self.init(imageSize: imageSize, classCount: classCount, layerBlockCounts: (3, 4, 23, 3))
            self.init(input: type, layerBlockCounts: (3, 4, 23, 3))
        case .resNet152:
            self.init(imageSize: imageSize, classCount: classCount, layerBlockCounts: (3, 8, 36, 3))
            self.init(input: type, layerBlockCounts: (3, 8, 36, 3))
        }
    }
}
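
With this change, callers pick the input configuration through the new InputType enum instead of passing imageSize and classCount separately. A minimal sketch of the resulting call sites, purely illustrative and assuming the rest of the ResNet target is in scope:

import TensorFlow

// Basic-block variant (ResNet-18/34), configured for 32x32 CIFAR-10 inputs.
var cifarModel = ResNetBasic(kind: .resNet34, type: .cifar)

// Bottleneck variant (ResNet-50/101/152), configured for 224x224 ImageNet inputs.
var imagenetModel = ResNet(kind: .resNet50, type: .imagenet)

One consequence of the switch-based initializer is that the class count is now fixed per case (10 for .cifar, 1000 for .imagenet) rather than supplied by the caller.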
2 changes: 1 addition & 1 deletion ResNet/main.swift
@@ -22,7 +22,7 @@ let cifarDataset = loadCIFAR10()
let testBatches = cifarDataset.test.batched(batchSize)

// Use the network sized for CIFAR-10
var model = ResNet(kind: .resNet50, imageSize: 32, classCount: 10)
var model = ResNet(kind: .resNet50, type: .cifar)

// the classic ImageNet optimizer setting diverges on CIFAR-10
// let optimizer = SGD(for: model, learningRate: 0.1, momentum: 0.9)
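
The hunk ends before the optimizer that replaces this setting is visible, so the value this PR actually uses is not shown here. As a hypothetical illustration only, a CIFAR-10 run might keep the SGD signature from the commented-out line but with a smaller learning rate:

// Hypothetical example; not the setting chosen by this PR, which the diff does not show.
let optimizer = SGD(for: model, learningRate: 0.01, momentum: 0.9)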