diff --git a/.gitignore b/.gitignore
index 02872079f84..4fea2e7e10e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -5,4 +5,7 @@
*.xcodeproj
*.png
.DS_Store
+.swiftpm
cifar-10-batches-py/
+cifar-10-batches-bin/
+output/
diff --git a/.swift-format b/.swift-format
new file mode 100644
index 00000000000..8bc224a9773
--- /dev/null
+++ b/.swift-format
@@ -0,0 +1,14 @@
+{
+ "version": 1,
+ "lineLength": 100,
+ "indentation": {
+ "spaces": 4
+ },
+ "maximumBlankLines": 1,
+ "respectsExistingLineBreaks": true,
+ "blankLineBetweenMembers": {
+ "ignoreSingleLineProperties": true
+ },
+ "lineBreakBeforeControlFlowKeywords": false,
+ "lineBreakBeforeEachArgument": false
+}
diff --git a/Autoencoder/README.md b/Autoencoder/README.md
index ff7dad539c3..a635a82cc93 100644
--- a/Autoencoder/README.md
+++ b/Autoencoder/README.md
@@ -1,5 +1,7 @@
# Simple Autoencoder
+This is an example of a simple 1-dimensional autoencoder model, using MNIST as a training dataset. It should produce output similar to the following:
+
### Epoch 1
@@ -12,7 +14,6 @@
-This directory builds a simple 1-dimensional autoencoder model.
## Setup
@@ -23,12 +24,5 @@ installed. Make sure you've added the correct version of `swift` to your path.
To train the model, run:
```
-swift run Autoencoder
-```
-If you using brew to install python2 and modules, change the path:
- - remove brew path '/usr/local/bin'
- - add TensorFlow swift Toolchain /Library/Developer/Toolchains/swift-latest/usr/bin
-
+swift run -c release Autoencoder
```
-export PATH=/Library/Developer/Toolchains/swift-latest/usr/bin:/usr/bin:/bin:/usr/sbin:/sbin:"${PATH}"
-```
diff --git a/Autoencoder/main.swift b/Autoencoder/main.swift
index f575c00760c..5cecc2f732b 100644
--- a/Autoencoder/main.swift
+++ b/Autoencoder/main.swift
@@ -12,126 +12,61 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+import Datasets
import Foundation
+import ModelSupport
import TensorFlow
-import Python
-// Import Python modules
-let matplotlib = Python.import("matplotlib")
-let np = Python.import("numpy")
-let plt = Python.import("matplotlib.pyplot")
-
-// Turn off using display on server / linux
-matplotlib.use("Agg")
-
-// Some globals
let epochCount = 10
let batchSize = 100
-let outputFolder = "./output/"
-let imageHeight = 28, imageWidth = 28
-
-func plot(image: [Float], name: String) {
- // Create figure
- let ax = plt.gca()
- let array = np.array([image])
- let pixels = array.reshape([imageHeight, imageWidth])
- if !FileManager.default.fileExists(atPath: outputFolder) {
- try! FileManager.default.createDirectory(atPath: outputFolder,
- withIntermediateDirectories: false,
- attributes: nil)
- }
- ax.imshow(pixels, cmap: "gray")
- plt.savefig("\(outputFolder)\(name).png", dpi: 300)
- plt.close()
-}
+let imageHeight = 28
+let imageWidth = 28
-/// Reads a file into an array of bytes.
-func readFile(_ filename: String) -> [UInt8] {
- let possibleFolders = [".", "Resources", "Autoencoder/Resources"]
- for folder in possibleFolders {
- let parent = URL(fileURLWithPath: folder)
- let filePath = parent.appendingPathComponent(filename).path
- guard FileManager.default.fileExists(atPath: filePath) else {
- continue
- }
- let d = Python.open(filePath, "rb").read()
- return Array(numpy: np.frombuffer(d, dtype: np.uint8))!
- }
- print("Failed to find file with name \(filename) in the following folders: \(possibleFolders).")
- exit(-1)
-}
-
-/// Reads MNIST images and labels from specified file paths.
-func readMNIST(imagesFile: String, labelsFile: String) -> (images: Tensor<Float>,
-                                                           labels: Tensor<Int32>) {
- print("Reading data.")
- let images = readFile(imagesFile).dropFirst(16).map { Float($0) }
- let labels = readFile(labelsFile).dropFirst(8).map { Int32($0) }
- let rowCount = labels.count
-
- print("Constructing data tensors.")
- return (
- images: Tensor(shape: [rowCount, imageHeight * imageWidth], scalars: images) / 255.0,
- labels: Tensor(labels)
- )
-}
-
-/// An autoencoder.
-struct Autoencoder: Layer {
-    typealias Input = Tensor<Float>
-    typealias Output = Tensor<Float>
-
-    var encoder1 = Dense<Float>(inputSize: imageHeight * imageWidth, outputSize: 128,
-                                activation: relu)
-    var encoder2 = Dense<Float>(inputSize: 128, outputSize: 64, activation: relu)
-    var encoder3 = Dense<Float>(inputSize: 64, outputSize: 12, activation: relu)
-    var encoder4 = Dense<Float>(inputSize: 12, outputSize: 3, activation: relu)
-
-    var decoder1 = Dense<Float>(inputSize: 3, outputSize: 12, activation: relu)
-    var decoder2 = Dense<Float>(inputSize: 12, outputSize: 64, activation: relu)
-    var decoder3 = Dense<Float>(inputSize: 64, outputSize: 128, activation: relu)
-    var decoder4 = Dense<Float>(inputSize: 128, outputSize: imageHeight * imageWidth,
-                                activation: tanh)
-
- @differentiable
- func call(_ input: Input) -> Output {
- let encoder = input.sequenced(through: encoder1, encoder2, encoder3, encoder4)
- return encoder.sequenced(through: decoder1, decoder2, decoder3, decoder4)
- }
-}
-
-// MNIST data logic
-func minibatch(in x: Tensor<Float>, at index: Int) -> Tensor<Float> {
-    let start = index * batchSize
-    return x[start..<start + batchSize]
+let outputFolder = "./output/"
+let dataset = MNIST(batchSize: batchSize, flattening: true)
+
+var autoencoder = Sequential {
+    // The encoder.
+    Dense<Float>(inputSize: imageHeight * imageWidth, outputSize: 128, activation: relu)
+    Dense<Float>(inputSize: 128, outputSize: 64, activation: relu)
+    Dense<Float>(inputSize: 64, outputSize: 12, activation: relu)
+    Dense<Float>(inputSize: 12, outputSize: 3, activation: relu)
+    // The decoder.
+    Dense<Float>(inputSize: 3, outputSize: 12, activation: relu)
+    Dense<Float>(inputSize: 12, outputSize: 64, activation: relu)
+    Dense<Float>(inputSize: 64, outputSize: 128, activation: relu)
+    Dense<Float>(inputSize: 128, outputSize: imageHeight * imageWidth, activation: tanh)
}
-
-let (images, numericLabels) = readMNIST(imagesFile: "train-images-idx3-ubyte",
- labelsFile: "train-labels-idx1-ubyte")
-let labels = Tensor<Float>(oneHotAtIndices: numericLabels, depth: 10)
-
-var autoencoder = Autoencoder()
let optimizer = RMSProp(for: autoencoder)
// Training loop
for epoch in 1...epochCount {
- let sampleImage = Tensor(shape: [1, imageHeight * imageWidth], scalars: images[epoch].scalars)
+ let sampleImage = Tensor(
+ shape: [1, imageHeight * imageWidth], scalars: dataset.trainingImages[epoch].scalars)
let testImage = autoencoder(sampleImage)
- plot(image: sampleImage.scalars, name: "epoch-\(epoch)-input")
- plot(image: testImage.scalars, name: "epoch-\(epoch)-output")
+ do {
+ try saveImage(
+ sampleImage, size: (imageWidth, imageHeight), directory: outputFolder,
+ name: "epoch-\(epoch)-input")
+ try saveImage(
+ testImage, size: (imageWidth, imageHeight), directory: outputFolder,
+ name: "epoch-\(epoch)-output")
+ } catch {
+ print("Could not save image with error: \(error)")
+ }
let sampleLoss = meanSquaredError(predicted: testImage, expected: sampleImage)
print("[Epoch: \(epoch)] Loss: \(sampleLoss)")
- for i in 0 ..< Int(labels.shape[0]) / batchSize {
- let x = minibatch(in: images, at: i)
+ for i in 0 ..< dataset.trainingSize / batchSize {
+ let x = dataset.trainingImages.minibatch(at: i, batchSize: batchSize)
         let 𝛁model = autoencoder.gradient { autoencoder -> Tensor<Float> in
let image = autoencoder(x)
return meanSquaredError(predicted: image, expected: x)
}
- optimizer.update(&autoencoder.allDifferentiableVariables, along: 𝛁model)
+ optimizer.update(&autoencoder, along: 𝛁model)
}
}
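
For reference (outside the patch itself): the training loop above reports `meanSquaredError(predicted:expected:)` for one sample image per epoch, which for a flattened image of n pixels is

```latex
\mathrm{MSE}(\hat{x}, x) = \frac{1}{n}\sum_{i=1}^{n}\left(\hat{x}_i - x_i\right)^2
```
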
diff --git a/CIFAR/Data.swift b/CIFAR/Data.swift
deleted file mode 100644
index dafac78848a..00000000000
--- a/CIFAR/Data.swift
+++ /dev/null
@@ -1,82 +0,0 @@
-// Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-import Python
-import TensorFlow
-
-/// Use Python and shell calls to download and extract the CIFAR-10 tarball if not already done
-/// This can fail for many reasons (e.g. lack of `wget`, `tar`, or an Internet connection)
-func downloadCIFAR10IfNotPresent(to directory: String = ".") {
- let subprocess = Python.import("subprocess")
- let path = Python.import("os.path")
- let filepath = "\(directory)/cifar-10-batches-py"
- let isdir = Bool(path.isdir(filepath))!
- if !isdir {
- print("Downloading CIFAR data...")
- let command = "wget -nv -O- https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz | tar xzf - -C \(directory)"
- subprocess.call(command, shell: true)
- }
-}
-
-struct Example: TensorGroup {
-    var label: Tensor<Int32>
-    var data: Tensor<Float>
-}
-
-// Each CIFAR data file is provided as a Python pickle of NumPy arrays
-func loadCIFARFile(named name: String, in directory: String = ".") -> Example {
- downloadCIFAR10IfNotPresent(to: directory)
- let np = Python.import("numpy")
- let pickle = Python.import("pickle")
- let path = "\(directory)/cifar-10-batches-py/\(name)"
- let f = Python.open(path, "rb")
- let res = pickle.load(f, encoding: "bytes")
-
- let bytes = res[Python.bytes("data", encoding: "utf8")]
- let labels = res[Python.bytes("labels", encoding: "utf8")]
-
-    let labelTensor = Tensor<Int64>(numpy: np.array(labels))!
-    let images = Tensor<UInt8>(numpy: bytes)!
- let imageCount = images.shape[0]
-
- // reshape and transpose from the provided N(CHW) to TF default NHWC
-    let imageTensor = Tensor<Float>(images
- .reshaped(to: [imageCount, 3, 32, 32])
- .transposed(withPermutations: [0, 2, 3, 1]))
-
-    let mean = Tensor<Float>([0.485, 0.456, 0.406])
-    let std = Tensor<Float>([0.229, 0.224, 0.225])
- let imagesNormalized = ((imageTensor / 255.0) - mean) / std
-
- return Example(label: Tensor(labelTensor), data: imagesNormalized)
-}
-
-func loadCIFARTrainingFiles() -> Example {
- let data = (1..<6).map { loadCIFARFile(named: "data_batch_\($0)") }
- return Example(
- label: Raw.concat(concatDim: Tensor(0), data.map { $0.label }),
- data: Raw.concat(concatDim: Tensor(0), data.map { $0.data })
- )
-}
-
-func loadCIFARTestFile() -> Example {
- return loadCIFARFile(named: "test_batch")
-}
-
-func loadCIFAR10() -> (
-    training: Dataset<Example>, test: Dataset<Example>) {
- let trainingDataset = Dataset(elements: loadCIFARTrainingFiles())
- let testDataset = Dataset(elements: loadCIFARTestFile())
- return (training: trainingDataset, test: testDataset)
-}
diff --git a/CIFAR/Helpers.swift b/CIFAR/Helpers.swift
deleted file mode 100644
index f4403efb1b5..00000000000
--- a/CIFAR/Helpers.swift
+++ /dev/null
@@ -1,51 +0,0 @@
-// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// TODO: Remove this when it's moved to the standard library.
-extension Array where Element: Differentiable {
- @differentiable(wrt: (self, initialResult), vjp: reduceDerivative)
-    func differentiableReduce<Result: Differentiable>(
- _ initialResult: Result,
- _ nextPartialResult: @differentiable (Result, Element) -> Result
- ) -> Result {
- return reduce(initialResult, nextPartialResult)
- }
-
-    func reduceDerivative<Result: Differentiable>(
- _ initialResult: Result,
- _ nextPartialResult: @differentiable (Result, Element) -> Result
- ) -> (Result, (Result.CotangentVector) -> (Array.CotangentVector, Result.CotangentVector)) {
- var pullbacks: [(Result.CotangentVector)
- -> (Result.CotangentVector, Element.CotangentVector)] = []
- let count = self.count
- pullbacks.reserveCapacity(count)
- var result = initialResult
- for element in self {
- let (y, pb) = Swift.valueWithPullback(at: result, element, in: nextPartialResult)
- result = y
- pullbacks.append(pb)
- }
- return (value: result, pullback: { cotangent in
- var resultCotangent = cotangent
- var elementCotangents = CotangentVector([])
- elementCotangents.base.reserveCapacity(count)
- for pullback in pullbacks.reversed() {
- let (newResultCotangent, elementCotangent) = pullback(resultCotangent)
- resultCotangent = newResultCotangent
- elementCotangents.base.append(elementCotangent)
- }
- return (CotangentVector(elementCotangents.base.reversed()), resultCotangent)
- })
- }
-}
diff --git a/CIFAR/README.md b/CIFAR/README.md
deleted file mode 100644
index 85e757c60c3..00000000000
--- a/CIFAR/README.md
+++ /dev/null
@@ -1,23 +0,0 @@
-# CIFAR
-
-This directory contains different example convolutional networks for image
-classification on the [CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html) dataset.
-
-## Setup
-
-You'll need [the latest version][INSTALL] of Swift for TensorFlow
-installed and added to your path. Additionally, the data loader requires Python
-3.x (rather than Python 2.7), `wget`, and `numpy`.
-
-> Note: For macOS, you need to set up the `PYTHON_LIBRARY` to help the Swift for
-> TensorFlow find the `libpython3.<version>.dylib` file, e.g., in
-> `homebrew`.
-
-To train the default model, run:
-
-```
-cd swift-models
-swift run -c release CIFAR
-```
-
-[INSTALL]: (https://github.com/tensorflow/swift/blob/master/Installation.md)
diff --git a/CIFAR/ResNet.swift b/CIFAR/ResNet.swift
deleted file mode 100644
index b14fab3d71f..00000000000
--- a/CIFAR/ResNet.swift
+++ /dev/null
@@ -1,119 +0,0 @@
-// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-import TensorFlow
-
-// Original Paper:
-// "Deep Residual Learning for Image Recognition"
-// Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
-// https://arxiv.org/abs/1512.03385
-// using shortcut layer to connect BasicBlock layers (aka Option (B))
-// see https://github.com/akamaster/pytorch_resnet_cifar10 for explanation
-
-struct Conv2DBatchNorm: Layer {
-    typealias Input = Tensor<Float>
-    typealias Output = Tensor<Float>
-
-    var conv: Conv2D<Float>
-    var norm: BatchNorm<Float>
-
- init(
- filterShape: (Int, Int, Int, Int),
- strides: (Int, Int) = (1, 1)
- ) {
- self.conv = Conv2D(filterShape: filterShape, strides: strides, padding: .same)
- self.norm = BatchNorm(featureCount: filterShape.3)
- }
-
- @differentiable
- func call(_ input: Input) -> Output {
- return input.sequenced(through: conv, norm)
- }
-}
-
-struct BasicBlock: Layer {
-    typealias Input = Tensor<Float>
-    typealias Output = Tensor<Float>
-
- var blocks: [Conv2DBatchNorm]
- var shortcut: Conv2DBatchNorm
-
- init(
- featureCounts: (Int, Int),
- kernelSize: Int = 3,
- strides: (Int, Int) = (2, 2),
- blockCount: Int = 3
- ) {
- self.blocks = [Conv2DBatchNorm(
- filterShape: (kernelSize, kernelSize, featureCounts.0, featureCounts.1),
- strides: strides)]
-        for _ in 2..<blockCount {
-            self.blocks += [Conv2DBatchNorm(
-                filterShape: (kernelSize, kernelSize, featureCounts.1, featureCounts.1))]
-        }
-        self.shortcut = Conv2DBatchNorm(
-            filterShape: (1, 1, featureCounts.0, featureCounts.1), strides: strides)
-    }
-
-    @differentiable
-    func call(_ input: Input) -> Output {
- let blocksReduced = blocks.differentiableReduce(input) { last, layer in
- relu(layer(last))
- }
- return relu(blocksReduced + shortcut(input))
- }
-}
-
-struct ResNet: Layer {
-    typealias Input = Tensor<Float>
-    typealias Output = Tensor<Float>
-
- var inputLayer = Conv2DBatchNorm(filterShape: (3, 3, 3, 16))
-
- var basicBlock1: BasicBlock
- var basicBlock2: BasicBlock
- var basicBlock3: BasicBlock
-
- init(blockCount: Int = 3) {
- basicBlock1 = BasicBlock(featureCounts:(16, 16), strides: (1, 1), blockCount: blockCount)
- basicBlock2 = BasicBlock(featureCounts:(16, 32), blockCount: blockCount)
- basicBlock3 = BasicBlock(featureCounts:(32, 64), blockCount: blockCount)
- }
-
-    var averagePool = AvgPool2D<Float>(poolSize: (8, 8), strides: (8, 8))
-    var flatten = Flatten<Float>()
-    var classifier = Dense<Float>(inputSize: 64, outputSize: 10, activation: softmax)
-
- @differentiable
- func call(_ input: Input) -> Output {
- let tmp = relu(inputLayer(input))
- let convolved = tmp.sequenced(through: basicBlock1, basicBlock2, basicBlock3)
- return convolved.sequenced(through: averagePool, flatten, classifier)
- }
-}
-
-extension ResNet {
- enum Kind: Int {
- case resNet20 = 3
- case resNet32 = 5
- case resNet44 = 7
- case resNet56 = 9
- case resNet110 = 18
- }
-
- init(kind: Kind) {
- self.init(blockCount: kind.rawValue)
- }
-}
diff --git a/CIFAR/WideResNet.swift b/CIFAR/WideResNet.swift
deleted file mode 100644
index a58b0529714..00000000000
--- a/CIFAR/WideResNet.swift
+++ /dev/null
@@ -1,177 +0,0 @@
-// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-import TensorFlow
-
-// Original Paper:
-// "Wide Residual Networks"
-// Sergey Zagoruyko, Nikos Komodakis
-// https://arxiv.org/abs/1605.07146
-// https://github.com/szagoruyko/wide-residual-networks
-
-struct BatchNormConv2DBlock: Layer {
-    typealias Input = Tensor<Float>
-    typealias Output = Tensor<Float>
-
-    var norm1: BatchNorm<Float>
-    var conv1: Conv2D<Float>
-    var norm2: BatchNorm<Float>
-    var conv2: Conv2D<Float>
-
- init(
- filterShape: (Int, Int, Int, Int),
- strides: (Int, Int) = (1, 1),
- padding: Padding = .same
- ) {
- self.norm1 = BatchNorm(featureCount: filterShape.2)
- self.conv1 = Conv2D(filterShape: filterShape, strides: strides, padding: padding)
- self.norm2 = BatchNorm(featureCount: filterShape.3)
- self.conv2 = Conv2D(filterShape: filterShape, strides: (1, 1), padding: padding)
- }
-
- @differentiable
- func call(_ input: Input) -> Output {
- let firstLayer = conv1(relu(norm1(input)))
- return conv2(relu(norm2(firstLayer)))
- }
-}
-
-struct WideResNetBasicBlock: Layer {
-    typealias Input = Tensor<Float>
-    typealias Output = Tensor<Float>
-
-    var blocks: [BatchNormConv2DBlock]
-    var shortcut: Conv2D<Float>
-
- init(
- featureCounts: (Int, Int),
- kernelSize: Int = 3,
- depthFactor: Int = 2,
- widenFactor: Int = 1,
- initialStride: (Int, Int) = (2, 2)
- ) {
- if initialStride == (1, 1) {
- self.blocks = [BatchNormConv2DBlock(
- filterShape: (kernelSize, kernelSize,
- featureCounts.0, featureCounts.1 * widenFactor),
- strides: initialStride)]
- self.shortcut = Conv2D(
- filterShape: (1, 1, featureCounts.0, featureCounts.1 * widenFactor),
- strides: initialStride)
- } else {
- self.blocks = [BatchNormConv2DBlock(
- filterShape: (kernelSize, kernelSize,
- featureCounts.0 * widenFactor, featureCounts.1 * widenFactor),
- strides: initialStride)]
- self.shortcut = Conv2D(
- filterShape: (1, 1, featureCounts.0 * widenFactor, featureCounts.1 * widenFactor),
- strides: initialStride)
- }
-        for _ in 1..<depthFactor {
-            self.blocks += [BatchNormConv2DBlock(
-                filterShape: (kernelSize, kernelSize,
-                              featureCounts.1 * widenFactor, featureCounts.1 * widenFactor))]
-        }
-    }
-
-    @differentiable
-    func call(_ input: Input) -> Output {
- let blocksReduced = blocks.differentiableReduce(input) { last, layer in
- relu(layer(last))
- }
- return relu(blocksReduced + shortcut(input))
- }
-}
-
-struct WideResNet: Layer {
-    typealias Input = Tensor<Float>
-    typealias Output = Tensor<Float>
-
-    var l1: Conv2D<Float>
-
-    var l2: WideResNetBasicBlock
-    var l3: WideResNetBasicBlock
-    var l4: WideResNetBasicBlock
-
-    var norm: BatchNorm<Float>
-    var avgPool: AvgPool2D<Float>
-    var flatten = Flatten<Float>()
-    var classifier: Dense<Float>
-
- init(depthFactor: Int = 2, widenFactor: Int = 8) {
- self.l1 = Conv2D(filterShape: (3, 3, 3, 16), strides: (1, 1), padding: .same)
-
- l2 = WideResNetBasicBlock(featureCounts: (16, 16), depthFactor: depthFactor,
- widenFactor: widenFactor, initialStride: (1, 1))
- l3 = WideResNetBasicBlock(featureCounts: (16, 32), depthFactor: depthFactor,
- widenFactor: widenFactor)
- l4 = WideResNetBasicBlock(featureCounts: (32, 64), depthFactor: depthFactor,
- widenFactor: widenFactor)
-
- self.norm = BatchNorm(featureCount: 64 * widenFactor)
- self.avgPool = AvgPool2D(poolSize: (8, 8), strides: (8, 8))
- self.classifier = Dense(inputSize: 64 * widenFactor, outputSize: 10)
- }
-
- @differentiable
- func call(_ input: Input) -> Output {
- let inputLayer = input.sequenced(through: l1, l2, l3, l4)
- let finalNorm = relu(norm(inputLayer))
- return finalNorm.sequenced(through: avgPool, flatten, classifier)
- }
-}
-
-extension WideResNet {
- enum Kind {
- case wideResNet16
- case wideResNet16k8
- case wideResNet16k10
- case wideResNet22
- case wideResNet22k8
- case wideResNet22k10
- case wideResNet28
- case wideResNet28k10
- case wideResNet28k12
- case wideResNet40k1
- case wideResNet40k2
- case wideResNet40k4
- case wideResNet40k8
- }
-
- init(kind: Kind) {
- switch kind {
- case .wideResNet16, .wideResNet16k8:
- self.init(depthFactor: 2, widenFactor: 8)
- case .wideResNet16k10:
- self.init(depthFactor: 2, widenFactor: 10)
- case .wideResNet22, .wideResNet22k8:
- self.init(depthFactor: 3, widenFactor: 8)
- case .wideResNet22k10:
- self.init(depthFactor: 3, widenFactor: 10)
- case .wideResNet28, .wideResNet28k10:
- self.init(depthFactor: 4, widenFactor: 10)
- case .wideResNet28k12:
- self.init(depthFactor: 4, widenFactor: 12)
- case .wideResNet40k1:
- self.init(depthFactor: 6, widenFactor: 1)
- case .wideResNet40k2:
- self.init(depthFactor: 6, widenFactor: 2)
- case .wideResNet40k4:
- self.init(depthFactor: 6, widenFactor: 4)
- case .wideResNet40k8:
- self.init(depthFactor: 6, widenFactor: 8)
- }
- }
-}
diff --git a/Catch/README.md b/Catch/README.md
index 566f7d449e1..136be5630c0 100644
--- a/Catch/README.md
+++ b/Catch/README.md
@@ -23,5 +23,5 @@ installed. Make sure you've added the correct version of `swift` to your path.
To train the model, run:
```
-swift -O catch.swift
+swift -O main.swift
```
diff --git a/Catch/main.swift b/Catch/main.swift
index a6192e6df3c..2ae144769b0 100644
--- a/Catch/main.swift
+++ b/Catch/main.swift
@@ -43,30 +43,20 @@ protocol Agent: AnyObject {
func step(observation: Observation, reward: Reward) -> Action
}
-struct Model: Layer {
-    typealias Input = Tensor<Float>
-    typealias Output = Tensor<Float>
-
-    var layer1 = Dense<Float>(inputSize: 3, outputSize: 50, activation: sigmoid,
-                              generator: &rng)
-    var layer2 = Dense<Float>(inputSize: 50, outputSize: 3, activation: sigmoid,
-                              generator: &rng)
-
- @differentiable
- func call(_ input: Input) -> Output {
- return input.sequenced(through: layer1, layer2)
- }
-}
-
class CatchAgent: Agent {
typealias Action = CatchAction
- var model: Model = Model()
-    let optimizer: Adam<Model>
+ var model = Sequential {
+        Dense<Float>(inputSize: 3, outputSize: 50, activation: sigmoid)
+        Dense<Float>(inputSize: 50, outputSize: 3, activation: sigmoid)
+ }
+
+ var learningRate: Float
+ lazy var optimizer = Adam(for: self.model, learningRate: self.learningRate)
var previousReward: Reward
init(initialReward: Reward, learningRate: Float) {
- optimizer = Adam(for: model, learningRate: learningRate)
+ self.learningRate = learningRate
previousReward = initialReward
}
}
@@ -83,9 +73,9 @@ extension CatchAgent {
let (ŷ, backprop) = model.appliedForBackpropagation(to: x)
let maxIndex = ŷ.argmax().scalarized()
- let 𝛁loss = -log(Tensor(ŷ.max())).broadcast(like: ŷ) * previousReward
+ let 𝛁loss = -log(Tensor(ŷ.max())).broadcasted(like: ŷ) * previousReward
let (𝛁model, _) = backprop(𝛁loss)
- optimizer.update(&model.allDifferentiableVariables, along: 𝛁model)
+ optimizer.update(&model, along: 𝛁model)
return CatchAction(rawValue: Int(maxIndex))!
}
diff --git a/Datasets/CIFAR10/CIFAR10.swift b/Datasets/CIFAR10/CIFAR10.swift
new file mode 100644
index 00000000000..7b74380adbd
--- /dev/null
+++ b/Datasets/CIFAR10/CIFAR10.swift
@@ -0,0 +1,135 @@
+// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Original source:
+// "The CIFAR-10 dataset"
+// Alex Krizhevsky, Vinod Nair, and Geoffrey Hinton.
+// https://www.cs.toronto.edu/~kriz/cifar.html
+
+import Foundation
+import TensorFlow
+
+#if canImport(FoundationNetworking)
+ import FoundationNetworking
+#endif
+
+public struct CIFAR10 {
+    public let trainingDataset: Dataset<CIFARExample>
+    public let testDataset: Dataset<CIFARExample>
+
+ public init() {
+ self.trainingDataset = Dataset(elements: loadCIFARTrainingFiles())
+ self.testDataset = Dataset(elements: loadCIFARTestFile())
+ }
+}
+
+func downloadCIFAR10IfNotPresent(to directory: String = ".") {
+ let downloadPath = "\(directory)/cifar-10-batches-bin"
+ let directoryExists = FileManager.default.fileExists(atPath: downloadPath)
+
+ guard !directoryExists else { return }
+
+ print("Downloading CIFAR dataset...")
+ let archivePath = "\(directory)/cifar-10-binary.tar.gz"
+ let archiveExists = FileManager.default.fileExists(atPath: archivePath)
+ if !archiveExists {
+ print("Archive missing, downloading...")
+ do {
+ let downloadedFile = try Data(
+ contentsOf: URL(
+ string: "https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz")!)
+ try downloadedFile.write(to: URL(fileURLWithPath: archivePath))
+ } catch {
+ print("Could not download CIFAR dataset, error: \(error)")
+ exit(-1)
+ }
+ }
+
+ print("Archive downloaded, processing...")
+
+ #if os(macOS)
+ let tarLocation = "/usr/bin/tar"
+ #else
+ let tarLocation = "/bin/tar"
+ #endif
+
+ let task = Process()
+ task.executableURL = URL(fileURLWithPath: tarLocation)
+ task.arguments = ["xzf", archivePath]
+ do {
+ try task.run()
+ task.waitUntilExit()
+ } catch {
+ print("CIFAR extraction failed with error: \(error)")
+ }
+
+ do {
+ try FileManager.default.removeItem(atPath: archivePath)
+ } catch {
+ print("Could not remove archive, error: \(error)")
+ exit(-1)
+ }
+
+ print("Unarchiving completed")
+}
+
+func loadCIFARFile(named name: String, in directory: String = ".") -> CIFARExample {
+ downloadCIFAR10IfNotPresent(to: directory)
+ let path = "\(directory)/cifar-10-batches-bin/\(name)"
+
+ let imageCount = 10000
+ guard let fileContents = try? Data(contentsOf: URL(fileURLWithPath: path)) else {
+ print("Could not read dataset file: \(name)")
+ exit(-1)
+ }
+ guard fileContents.count == 30_730_000 else {
+ print(
+ "Dataset file \(name) should have 30730000 bytes, instead had \(fileContents.count)")
+ exit(-1)
+ }
+
+ var bytes: [UInt8] = []
+ var labels: [Int64] = []
+
+ let imageByteSize = 3073
+    for imageIndex in 0..<imageCount {
+        let baseAddress = imageIndex * imageByteSize
+        labels.append(Int64(fileContents[baseAddress]))
+        bytes.append(contentsOf: fileContents[(baseAddress + 1)..<(baseAddress + imageByteSize)])
+    }
+
+    let labelTensor = Tensor<Int64>(shape: [imageCount], scalars: labels)
+ let images = Tensor(shape: [imageCount, 3, 32, 32], scalars: bytes)
+
+ // Transpose from the CIFAR-provided N(CHW) to TF's default NHWC.
+    let imageTensor = Tensor<Float>(images.transposed(withPermutations: [0, 2, 3, 1]))
+
+    let mean = Tensor<Float>([0.485, 0.456, 0.406])
+    let std = Tensor<Float>([0.229, 0.224, 0.225])
+ let imagesNormalized = ((imageTensor / 255.0) - mean) / std
+
+ return CIFARExample(label: Tensor(labelTensor), data: imagesNormalized)
+}
+
+func loadCIFARTrainingFiles() -> CIFARExample {
+ let data = (1..<6).map { loadCIFARFile(named: "data_batch_\($0).bin") }
+ return CIFARExample(
+ label: Raw.concat(concatDim: Tensor(0), data.map { $0.label }),
+ data: Raw.concat(concatDim: Tensor(0), data.map { $0.data })
+ )
+}
+
+func loadCIFARTestFile() -> CIFARExample {
+ return loadCIFARFile(named: "test_batch.bin")
+}
diff --git a/Datasets/CIFAR10/CIFARExample.swift b/Datasets/CIFAR10/CIFARExample.swift
new file mode 100644
index 00000000000..ac9bd888609
--- /dev/null
+++ b/Datasets/CIFAR10/CIFARExample.swift
@@ -0,0 +1,35 @@
+// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+import TensorFlow
+
+public struct CIFARExample: TensorGroup {
+    public var label: Tensor<Int32>
+    public var data: Tensor<Float>
+
+    public init(label: Tensor<Int32>, data: Tensor<Float>) {
+ self.label = label
+ self.data = data
+ }
+
+    public init<C: RandomAccessCollection>(
+ _handles: C
+ ) where C.Element: _AnyTensorHandle {
+ precondition(_handles.count == 2)
+ let labelIndex = _handles.startIndex
+ let dataIndex = _handles.index(labelIndex, offsetBy: 1)
+        label = Tensor<Int32>(handle: TensorHandle<Int32>(handle: _handles[labelIndex]))
+        data = Tensor<Float>(handle: TensorHandle<Float>(handle: _handles[dataIndex]))
+ }
+}
diff --git a/Datasets/MNIST/MNIST.swift b/Datasets/MNIST/MNIST.swift
new file mode 100644
index 00000000000..18256088a3f
--- /dev/null
+++ b/Datasets/MNIST/MNIST.swift
@@ -0,0 +1,109 @@
+// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Original source:
+// "The MNIST database of handwritten digits"
+// Yann LeCun, Corinna Cortes, and Christopher J.C. Burges
+// http://yann.lecun.com/exdb/mnist/
+
+import Foundation
+import TensorFlow
+
+public struct MNIST {
+    public let trainingImages: Tensor<Float>
+    public let trainingLabels: Tensor<Int32>
+    public let testImages: Tensor<Float>
+    public let testLabels: Tensor<Int32>
+
+ public let trainingSize: Int
+ public let testSize: Int
+
+ public let batchSize: Int
+
+ public init(batchSize: Int, flattening: Bool = false, normalizing: Bool = false) {
+ self.batchSize = batchSize
+
+ let (trainingImages, trainingLabels) = readMNIST(
+ imagesFile: "train-images-idx3-ubyte",
+ labelsFile: "train-labels-idx1-ubyte",
+ flattening: flattening,
+ normalizing: normalizing)
+ self.trainingImages = trainingImages
+ self.trainingLabels = trainingLabels
+ self.trainingSize = Int(trainingLabels.shape[0])
+
+ let (testImages, testLabels) = readMNIST(
+ imagesFile: "t10k-images-idx3-ubyte",
+ labelsFile: "t10k-labels-idx1-ubyte",
+ flattening: flattening,
+ normalizing: normalizing)
+ self.testImages = testImages
+ self.testLabels = testLabels
+ self.testSize = Int(testLabels.shape[0])
+ }
+}
+
+extension Tensor {
+ public func minibatch(at index: Int, batchSize: Int) -> Tensor {
+ let start = index * batchSize
+        return self[start..<start + batchSize]
+    }
+}
+
+/// Reads a file into an array of bytes.
+func readFile(_ path: String, possibleDirectories: [String]) -> [UInt8] {
+ for folder in possibleDirectories {
+ let parent = URL(fileURLWithPath: folder)
+ let filePath = parent.appendingPathComponent(path)
+ guard FileManager.default.fileExists(atPath: filePath.path) else {
+ continue
+ }
+ let data = try! Data(contentsOf: filePath, options: [])
+ return [UInt8](data)
+ }
+ print("File not found: \(path)")
+ exit(-1)
+}
+
+/// Reads MNIST images and labels from specified file paths.
+func readMNIST(imagesFile: String, labelsFile: String, flattening: Bool, normalizing: Bool) -> (
+    images: Tensor<Float>,
+    labels: Tensor<Int32>
+) {
+ print("Reading data from files: \(imagesFile), \(labelsFile).")
+ let images = readFile(imagesFile, possibleDirectories: [".", "./Datasets/MNIST"]).dropFirst(16)
+ .map(Float.init)
+ let labels = readFile(labelsFile, possibleDirectories: [".", "./Datasets/MNIST"]).dropFirst(8)
+ .map(Int32.init)
+ let rowCount = labels.count
+ let imageHeight = 28
+ let imageWidth = 28
+
+ print("Constructing data tensors.")
+
+ if flattening {
+ var flattenedImages = Tensor(shape: [rowCount, imageHeight * imageWidth], scalars: images)
+ / 255.0
+ if normalizing {
+ flattenedImages = flattenedImages * 2.0 - 1.0
+ }
+ return (images: flattenedImages, labels: Tensor(labels))
+ } else {
+ return (
+ images: Tensor(shape: [rowCount, 1, imageHeight, imageWidth], scalars: images)
+ .transposed(withPermutations: [0, 2, 3, 1]) / 255, // NHWC
+ labels: Tensor(labels)
+ )
+ }
+}
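
For reference (outside the patch itself), a minimal sketch of the new `MNIST` loader together with the `minibatch(at:batchSize:)` helper it defines; the batch size is an arbitrary example value, and flattening/normalizing are left at their defaults:

```swift
import Datasets
import TensorFlow

let batchSize = 128
let dataset = MNIST(batchSize: batchSize)

for i in 0 ..< dataset.trainingSize / batchSize {
    let images = dataset.trainingImages.minibatch(at: i, batchSize: batchSize)
    let labels = dataset.trainingLabels.minibatch(at: i, batchSize: batchSize)
    print(images.shape, labels.shape)  // [128, 28, 28, 1] and [128]
}
```
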
diff --git a/Datasets/MNIST/t10k-images-idx3-ubyte b/Datasets/MNIST/t10k-images-idx3-ubyte
new file mode 100644
index 00000000000..1170b2cae98
Binary files /dev/null and b/Datasets/MNIST/t10k-images-idx3-ubyte differ
diff --git a/Datasets/MNIST/t10k-labels-idx1-ubyte b/Datasets/MNIST/t10k-labels-idx1-ubyte
new file mode 100644
index 00000000000..d1c3a970612
Binary files /dev/null and b/Datasets/MNIST/t10k-labels-idx1-ubyte differ
diff --git a/Autoencoder/Resources/train-images-idx3-ubyte b/Datasets/MNIST/train-images-idx3-ubyte
similarity index 100%
rename from Autoencoder/Resources/train-images-idx3-ubyte
rename to Datasets/MNIST/train-images-idx3-ubyte
diff --git a/Autoencoder/Resources/train-labels-idx1-ubyte b/Datasets/MNIST/train-labels-idx1-ubyte
similarity index 100%
rename from Autoencoder/Resources/train-labels-idx1-ubyte
rename to Datasets/MNIST/train-labels-idx1-ubyte
diff --git a/CIFAR/Models.swift b/Examples/Custom-CIFAR10/Models.swift
similarity index 96%
rename from CIFAR/Models.swift
rename to Examples/Custom-CIFAR10/Models.swift
index 516b9f14838..4608beada6c 100644
--- a/CIFAR/Models.swift
+++ b/Examples/Custom-CIFAR10/Models.swift
@@ -29,7 +29,7 @@ struct PyTorchModel: Layer {
     var dense3 = Dense<Float>(inputSize: 84, outputSize: 10, activation: identity)
@differentiable
- func call(_ input: Input) -> Output {
+ func callAsFunction(_ input: Input) -> Output {
let convolved = input.sequenced(through: conv1, pool1, conv2, pool2)
return convolved.sequenced(through: flatten, dense1, dense2, dense3)
}
@@ -54,7 +54,7 @@ struct KerasModel: Layer {
     var dense2 = Dense<Float>(inputSize: 512, outputSize: 10, activation: identity)
@differentiable
- func call(_ input: Input) -> Output {
+ func callAsFunction(_ input: Input) -> Output {
let conv1 = input.sequenced(through: conv1a, conv1b, pool1, dropout1)
let conv2 = conv1.sequenced(through: conv2a, conv2b, pool2, dropout2)
return conv2.sequenced(through: flatten, dense1, dropout3, dense2)
diff --git a/Examples/Custom-CIFAR10/README.md b/Examples/Custom-CIFAR10/README.md
new file mode 100644
index 00000000000..e09baafa07b
--- /dev/null
+++ b/Examples/Custom-CIFAR10/README.md
@@ -0,0 +1,18 @@
+# CIFAR-10 with custom models
+
+This example demonstrates how to train the custom-defined models (based on examples from [PyTorch](https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html) and [Keras](https://github.com/keras-team/keras/blob/master/examples/cifar10_cnn.py) ) against the [CIFAR-10 image classification dataset](https://www.cs.toronto.edu/~kriz/cifar.html).
+
+Two custom models are defined, and one is applied to an instance of the CIFAR-10 dataset. A custom training loop is defined, and the training and test losses and accuracies for each epoch are shown during training.
+
+## Setup
+
+To begin, you'll need the [latest version of Swift for
+TensorFlow](https://github.com/tensorflow/swift/blob/master/Installation.md)
+installed. Make sure you've added the correct version of `swift` to your path.
+
+To train the model, run:
+
+```sh
+cd swift-models
+swift run -c release Custom-CIFAR10
+```
diff --git a/CIFAR/main.swift b/Examples/Custom-CIFAR10/main.swift
similarity index 83%
rename from CIFAR/main.swift
rename to Examples/Custom-CIFAR10/main.swift
index b47222b59cb..be27fff5a59 100644
--- a/CIFAR/main.swift
+++ b/Examples/Custom-CIFAR10/main.swift
@@ -12,14 +12,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+import Datasets
import TensorFlow
-import Python
-PythonLibrary.useVersion(3)
let batchSize = 100
-let cifarDataset = loadCIFAR10()
-let testBatches = cifarDataset.test.batched(batchSize)
+let dataset = CIFAR10()
+let testBatches = dataset.testDataset.batched(batchSize)
var model = KerasModel()
let optimizer = RMSProp(for: model, learningRate: 0.0001, decay: 1e-6)
@@ -30,7 +29,7 @@ Context.local.learningPhase = .training
for epoch in 1...100 {
var trainingLossSum: Float = 0
var trainingBatchCount = 0
- let trainingShuffled = cifarDataset.training.shuffled(
+ let trainingShuffled = dataset.trainingDataset.shuffled(
sampleCount: 50000, randomSeed: Int64(epoch))
for batch in trainingShuffled.batched(batchSize) {
let (labels, images) = (batch.label, batch.data)
@@ -40,7 +39,7 @@ for epoch in 1...100 {
}
trainingLossSum += loss.scalarized()
trainingBatchCount += 1
- optimizer.update(&model.allDifferentiableVariables, along: gradients)
+ optimizer.update(&model, along: gradients)
}
var testLossSum: Float = 0
@@ -54,15 +53,17 @@ for epoch in 1...100 {
testBatchCount += 1
let correctPredictions = logits.argmax(squeezingAxis: 1) .== labels
- correctGuessCount = correctGuessCount +
- Int(Tensor(correctPredictions).sum().scalarized())
+ correctGuessCount = correctGuessCount + Int(
+ Tensor(correctPredictions).sum().scalarized())
totalGuessCount = totalGuessCount + batchSize
}
let accuracy = Float(correctGuessCount) / Float(totalGuessCount)
- print("""
+ print(
+ """
[Epoch \(epoch)] \
Accuracy: \(correctGuessCount)/\(totalGuessCount) (\(accuracy)) \
Loss: \(testLossSum / Float(testBatchCount))
- """)
+ """
+ )
}
diff --git a/Examples/LeNet-MNIST/README.md b/Examples/LeNet-MNIST/README.md
new file mode 100644
index 00000000000..e88903c1e2c
--- /dev/null
+++ b/Examples/LeNet-MNIST/README.md
@@ -0,0 +1,19 @@
+# LeNet-5 with MNIST
+
+This example demonstrates how to train the [LeNet-5 network]( http://yann.lecun.com/exdb/publis/pdf/lecun-01a.pdf) against the [MNIST digit classification dataset](http://yann.lecun.com/exdb/mnist/).
+
+The LeNet network is instantiated from the ImageClassificationModels library of standard models, and applied to an instance of the MNIST dataset. A custom training loop is defined, and the training and test losses and accuracies for each epoch are shown during training.
+
+
+## Setup
+
+To begin, you'll need the [latest version of Swift for
+TensorFlow](https://github.com/tensorflow/swift/blob/master/Installation.md)
+installed. Make sure you've added the correct version of `swift` to your path.
+
+To train the model, run:
+
+```sh
+cd swift-models
+swift run -c release LeNet-MNIST
+```
diff --git a/Examples/LeNet-MNIST/main.swift b/Examples/LeNet-MNIST/main.swift
new file mode 100644
index 00000000000..77770eaa3e1
--- /dev/null
+++ b/Examples/LeNet-MNIST/main.swift
@@ -0,0 +1,91 @@
+// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+import TensorFlow
+import Datasets
+
+let epochCount = 12
+let batchSize = 128
+
+let dataset = MNIST(batchSize: batchSize)
+// The LeNet-5 model, equivalent to `LeNet` in `ImageClassificationModels`.
+var classifier = Sequential {
+    Conv2D<Float>(filterShape: (5, 5, 1, 6), padding: .same, activation: relu)
+    AvgPool2D<Float>(poolSize: (2, 2), strides: (2, 2))
+    Conv2D<Float>(filterShape: (5, 5, 6, 16), activation: relu)
+    AvgPool2D<Float>(poolSize: (2, 2), strides: (2, 2))
+    Flatten<Float>()
+    Dense<Float>(inputSize: 400, outputSize: 120, activation: relu)
+    Dense<Float>(inputSize: 120, outputSize: 84, activation: relu)
+    Dense<Float>(inputSize: 84, outputSize: 10, activation: softmax)
+}
+
+let optimizer = SGD(for: classifier, learningRate: 0.1)
+
+print("Beginning training...")
+
+struct Statistics {
+ var correctGuessCount: Int = 0
+ var totalGuessCount: Int = 0
+ var totalLoss: Float = 0
+}
+
+// The training loop.
+for epoch in 1...epochCount {
+ var trainStats = Statistics()
+ var testStats = Statistics()
+ Context.local.learningPhase = .training
+ for i in 0 ..< dataset.trainingSize / batchSize {
+ let x = dataset.trainingImages.minibatch(at: i, batchSize: batchSize)
+ let y = dataset.trainingLabels.minibatch(at: i, batchSize: batchSize)
+ // Compute the gradient with respect to the model.
+        let 𝛁model = classifier.gradient { classifier -> Tensor<Float> in
+ let ŷ = classifier(x)
+ let correctPredictions = ŷ.argmax(squeezingAxis: 1) .== y
+ trainStats.correctGuessCount += Int(
+                Tensor<Int32>(correctPredictions).sum().scalarized())
+ trainStats.totalGuessCount += batchSize
+ let loss = softmaxCrossEntropy(logits: ŷ, labels: y)
+ trainStats.totalLoss += loss.scalarized()
+ return loss
+ }
+ // Update the model's differentiable variables along the gradient vector.
+ optimizer.update(&classifier, along: 𝛁model)
+ }
+
+ Context.local.learningPhase = .inference
+ for i in 0 ..< dataset.testSize / batchSize {
+ let x = dataset.testImages.minibatch(at: i, batchSize: batchSize)
+ let y = dataset.testLabels.minibatch(at: i, batchSize: batchSize)
+ // Compute loss on test set
+ let ŷ = classifier(x)
+ let correctPredictions = ŷ.argmax(squeezingAxis: 1) .== y
+        testStats.correctGuessCount += Int(Tensor<Int32>(correctPredictions).sum().scalarized())
+ testStats.totalGuessCount += batchSize
+ let loss = softmaxCrossEntropy(logits: ŷ, labels: y)
+ testStats.totalLoss += loss.scalarized()
+ }
+
+ let trainAccuracy = Float(trainStats.correctGuessCount) / Float(trainStats.totalGuessCount)
+ let testAccuracy = Float(testStats.correctGuessCount) / Float(testStats.totalGuessCount)
+ print("""
+ [Epoch \(epoch)] \
+ Training Loss: \(trainStats.totalLoss), \
+ Training Accuracy: \(trainStats.correctGuessCount)/\(trainStats.totalGuessCount) \
+ (\(trainAccuracy)), \
+ Test Loss: \(testStats.totalLoss), \
+ Test Accuracy: \(testStats.correctGuessCount)/\(testStats.totalGuessCount) \
+ (\(testAccuracy))
+ """)
+}
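
For reference (outside the patch itself): the `softmaxCrossEntropy(logits:labels:)` loss used in both loops is the usual categorical cross-entropy of the softmaxed logits; for one example with logit vector z and true class y,

```latex
\ell(z, y) = -\log \frac{e^{z_y}}{\sum_{j} e^{z_j}} = -z_y + \log \sum_{j} e^{z_j}
```
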
diff --git a/Examples/ResNet-CIFAR10/README.md b/Examples/ResNet-CIFAR10/README.md
new file mode 100644
index 00000000000..de69a07d66a
--- /dev/null
+++ b/Examples/ResNet-CIFAR10/README.md
@@ -0,0 +1,18 @@
+# ResNet-50 with CIFAR-10
+
+This example demonstrates how to train the [ResNet-50 network]( https://arxiv.org/abs/1512.03385) against the [CIFAR-10 image classification dataset](https://www.cs.toronto.edu/~kriz/cifar.html).
+
+A modified ResNet-50 network is instantiated from the ImageClassificationModels library of standard models, and applied to an instance of the CIFAR-10 dataset. A custom training loop is defined, and the training and test losses and accuracies for each epoch are shown during training.
+
+## Setup
+
+To begin, you'll need the [latest version of Swift for
+TensorFlow](https://github.com/tensorflow/swift/blob/master/Installation.md)
+installed. Make sure you've added the correct version of `swift` to your path.
+
+To train the model, run:
+
+```sh
+cd swift-models
+swift run -c release ResNet-CIFAR10
+```
diff --git a/ResNet/main.swift b/Examples/ResNet-CIFAR10/main.swift
similarity index 78%
rename from ResNet/main.swift
rename to Examples/ResNet-CIFAR10/main.swift
index ee00f3ba9e4..edc1ce08488 100644
--- a/ResNet/main.swift
+++ b/Examples/ResNet-CIFAR10/main.swift
@@ -12,18 +12,17 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+import Datasets
+import ImageClassificationModels
import TensorFlow
-import Python
-PythonLibrary.useVersion(3)
let batchSize = 100
-let cifarDataset = loadCIFAR10()
-let testBatches = cifarDataset.test.batched(batchSize)
+let dataset = CIFAR10()
+let testBatches = dataset.testDataset.batched(batchSize)
-// ResNet18, ResNet34, ResNet50, ResNet101, ResNet152
-// PreActivatedResNet18, PreActivatedResNet34
-var model = ResNet50(imageSize: 32, classCount: 10) // Use the network sized for CIFAR-10
+// Use the network sized for CIFAR-10
+var model = ResNet(inputKind: .resNet50, dataKind: .cifar)
// the classic ImageNet optimizer setting diverges on CIFAR-10
// let optimizer = SGD(for: model, learningRate: 0.1, momentum: 0.9)
@@ -35,7 +34,7 @@ Context.local.learningPhase = .training
for epoch in 1...10 {
var trainingLossSum: Float = 0
var trainingBatchCount = 0
- let trainingShuffled = cifarDataset.training.shuffled(
+ let trainingShuffled = dataset.trainingDataset.shuffled(
sampleCount: 50000, randomSeed: Int64(epoch))
for batch in trainingShuffled.batched(batchSize) {
let (labels, images) = (batch.label, batch.data)
@@ -45,7 +44,7 @@ for epoch in 1...10 {
}
trainingLossSum += loss.scalarized()
trainingBatchCount += 1
- optimizer.update(&model.allDifferentiableVariables, along: gradients)
+ optimizer.update(&model, along: gradients)
}
var testLossSum: Float = 0
var testBatchCount = 0
@@ -58,15 +57,17 @@ for epoch in 1...10 {
testBatchCount += 1
let correctPredictions = logits.argmax(squeezingAxis: 1) .== labels
- correctGuessCount = correctGuessCount +
- Int(Tensor(correctPredictions).sum().scalarized())
+ correctGuessCount = correctGuessCount + Int(
+ Tensor(correctPredictions).sum().scalarized())
totalGuessCount = totalGuessCount + batchSize
}
let accuracy = Float(correctGuessCount) / Float(totalGuessCount)
- print("""
+ print(
+ """
[Epoch \(epoch)] \
Accuracy: \(correctGuessCount)/\(totalGuessCount) (\(accuracy)) \
Loss: \(testLossSum / Float(testBatchCount))
- """)
+ """
+ )
}
diff --git a/MNIST/README.md b/GAN/README.md
similarity index 51%
rename from MNIST/README.md
rename to GAN/README.md
index bd735738eb6..0482aa75999 100644
--- a/MNIST/README.md
+++ b/GAN/README.md
@@ -1,7 +1,14 @@
-# MNIST
+# Simple GAN
-This directory builds a simple convolutional neural network to classify the
-[MNIST dataset](http://yann.lecun.com/exdb/mnist/).
+After Epoch 1:
+
+<img src="images/epoch-1-output.png">
+
+After Epoch 10:
+
+<img src="images/epoch-10-output.png">
+
## Setup
@@ -11,7 +18,6 @@ installed. Make sure you've added the correct version of `swift` to your path.
To train the model, run:
-```
-cd swift-models
-swift run -c release MNIST
+```sh
+swift run GAN
```
diff --git a/GAN/images/epoch-1-output.png b/GAN/images/epoch-1-output.png
new file mode 100644
index 00000000000..b7fd99b7295
Binary files /dev/null and b/GAN/images/epoch-1-output.png differ
diff --git a/GAN/images/epoch-10-output.png b/GAN/images/epoch-10-output.png
new file mode 100644
index 00000000000..66669272db4
Binary files /dev/null and b/GAN/images/epoch-10-output.png differ
diff --git a/GAN/main.swift b/GAN/main.swift
new file mode 100644
index 00000000000..f57a550274f
--- /dev/null
+++ b/GAN/main.swift
@@ -0,0 +1,188 @@
+// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+import Datasets
+import Foundation
+import ModelSupport
+import TensorFlow
+
+let epochCount = 10
+let batchSize = 32
+let outputFolder = "./output/"
+let imageHeight = 28
+let imageWidth = 28
+let imageSize = imageHeight * imageWidth
+let latentSize = 64
+
+// Models
+
+struct Generator: Layer {
+    var dense1 = Dense<Float>(
+ inputSize: latentSize, outputSize: latentSize * 2,
+ activation: { leakyRelu($0) })
+
+    var dense2 = Dense<Float>(
+ inputSize: latentSize * 2, outputSize: latentSize * 4,
+ activation: { leakyRelu($0) })
+
+    var dense3 = Dense<Float>(
+ inputSize: latentSize * 4, outputSize: latentSize * 8,
+ activation: { leakyRelu($0) })
+
+    var dense4 = Dense<Float>(
+ inputSize: latentSize * 8, outputSize: imageSize,
+ activation: tanh)
+
+    var batchnorm1 = BatchNorm<Float>(featureCount: latentSize * 2)
+    var batchnorm2 = BatchNorm<Float>(featureCount: latentSize * 4)
+    var batchnorm3 = BatchNorm<Float>(featureCount: latentSize * 8)
+
+ @differentiable
+    func callAsFunction(_ input: Tensor<Float>) -> Tensor<Float> {
+ let x1 = batchnorm1(dense1(input))
+ let x2 = batchnorm2(dense2(x1))
+ let x3 = batchnorm3(dense3(x2))
+ return dense4(x3)
+ }
+}
+
+struct Discriminator: Layer {
+    var dense1 = Dense<Float>(
+ inputSize: imageSize, outputSize: 256,
+ activation: { leakyRelu($0) })
+
+    var dense2 = Dense<Float>(
+ inputSize: 256, outputSize: 64,
+ activation: { leakyRelu($0) })
+
+    var dense3 = Dense<Float>(
+ inputSize: 64, outputSize: 16,
+ activation: { leakyRelu($0) })
+
+    var dense4 = Dense<Float>(
+ inputSize: 16, outputSize: 1,
+ activation: identity)
+
+ @differentiable
+    func callAsFunction(_ input: Tensor<Float>) -> Tensor<Float> {
+ input.sequenced(through: dense1, dense2, dense3, dense4)
+ }
+}
+
+// Loss functions
+
+@differentiable
+func generatorLoss(fakeLogits: Tensor<Float>) -> Tensor<Float> {
+ sigmoidCrossEntropy(
+ logits: fakeLogits,
+ labels: Tensor(ones: fakeLogits.shape))
+}
+
+@differentiable
+func discriminatorLoss(realLogits: Tensor<Float>, fakeLogits: Tensor<Float>) -> Tensor<Float> {
+ let realLoss = sigmoidCrossEntropy(
+ logits: realLogits,
+ labels: Tensor(ones: realLogits.shape))
+ let fakeLoss = sigmoidCrossEntropy(
+ logits: fakeLogits,
+ labels: Tensor(zeros: fakeLogits.shape))
+ return realLoss + fakeLoss
+}
+
+/// Returns `size` samples of noise vector.
+func sampleVector(size: Int) -> Tensor<Float> {
+ Tensor(randomNormal: [size, latentSize])
+}
+
+let dataset = MNIST(batchSize: batchSize, flattening: true, normalizing: true)
+
+var generator = Generator()
+var discriminator = Discriminator()
+
+let optG = Adam(for: generator, learningRate: 2e-4, beta1: 0.5)
+let optD = Adam(for: discriminator, learningRate: 2e-4, beta1: 0.5)
+
+// Noise vectors and plot function for testing
+let testImageGridSize = 4
+let testVector = sampleVector(size: testImageGridSize * testImageGridSize)
+
+func saveImageGrid(_ testImage: Tensor<Float>, name: String) throws {
+ var gridImage = testImage.reshaped(
+ to: [
+ testImageGridSize, testImageGridSize,
+ imageHeight, imageWidth,
+ ])
+ // Add padding.
+ gridImage = gridImage.padded(forSizes: [(0, 0), (0, 0), (1, 1), (1, 1)], with: 1)
+ // Transpose to create single image.
+ gridImage = gridImage.transposed(withPermutations: [0, 2, 1, 3])
+ gridImage = gridImage.reshaped(
+ to: [
+ (imageHeight + 2) * testImageGridSize,
+ (imageWidth + 2) * testImageGridSize,
+ ])
+ // Convert [-1, 1] range to [0, 1] range.
+ gridImage = (gridImage + 1) / 2
+
+ try saveImage(
+ gridImage, size: (gridImage.shape[0], gridImage.shape[1]), directory: outputFolder,
+ name: name)
+}
+
+print("Start training...")
+
+// Start training loop.
+for epoch in 1...epochCount {
+ // Start training phase.
+ Context.local.learningPhase = .training
+ for i in 0 ..< dataset.trainingSize / batchSize {
+ // Perform alternative update.
+ // Update generator.
+ let vec1 = sampleVector(size: batchSize)
+
+        let 𝛁generator = generator.gradient { generator -> Tensor<Float> in
+ let fakeImages = generator(vec1)
+ let fakeLogits = discriminator(fakeImages)
+ let loss = generatorLoss(fakeLogits: fakeLogits)
+ return loss
+ }
+ optG.update(&generator, along: 𝛁generator)
+
+ // Update discriminator.
+ let realImages = dataset.trainingImages.minibatch(at: i, batchSize: batchSize)
+ let vec2 = sampleVector(size: batchSize)
+ let fakeImages = generator(vec2)
+
+        let 𝛁discriminator = discriminator.gradient { discriminator -> Tensor<Float> in
+ let realLogits = discriminator(realImages)
+ let fakeLogits = discriminator(fakeImages)
+ let loss = discriminatorLoss(realLogits: realLogits, fakeLogits: fakeLogits)
+ return loss
+ }
+ optD.update(&discriminator, along: 𝛁discriminator)
+ }
+
+ // Start inference phase.
+ Context.local.learningPhase = .inference
+ let testImage = generator(testVector)
+
+ do {
+ try saveImageGrid(testImage, name: "epoch-\(epoch)-output")
+ } catch {
+ print("Could not save image grid with error: \(error)")
+ }
+
+ let lossG = generatorLoss(fakeLogits: testImage)
+ print("[Epoch: \(epoch)] Loss-G: \(lossG)")
+}
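
For reference (outside the patch itself): with the all-ones and all-zeros labels used above, `sigmoidCrossEntropy` gives the standard non-saturating GAN objectives. Writing σ for the sigmoid and D(·) for the discriminator's logit output, the per-example losses are

```latex
L_D = -\log \sigma(D(x)) - \log\bigl(1 - \sigma(D(G(z)))\bigr),
\qquad
L_G = -\log \sigma(D(G(z)))
```
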
diff --git a/Gym/Blackjack/main.swift b/Gym/Blackjack/main.swift
new file mode 100644
index 00000000000..cdd5ef7c71c
--- /dev/null
+++ b/Gym/Blackjack/main.swift
@@ -0,0 +1,182 @@
+// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+import Python
+import TensorFlow
+
+let gym = Python.import("gym")
+let environment = gym.make("Blackjack-v0")
+
+let iterationCount = 10000
+let learningPhase = iterationCount * 5 / 100
+
+typealias Strategy = Bool
+
+class BlackjackState {
+ var playerSum: Int = 0
+ var dealerCard: Int = 0
+ var useableAce: Int = 0
+
+ init(pythonState: PythonObject) {
+ self.playerSum = Int(pythonState[0]) ?? 0
+ self.dealerCard = Int(pythonState[1]) ?? 0
+ self.useableAce = Int(pythonState[2]) ?? 0
+ }
+}
+
+enum SolverType: CaseIterable {
+ case random, markov, qlearning, normal
+}
+
+class Solver {
+ var Q: [[[[Float]]]] = []
+ var alpha: Float = 0.5
+ let gamma: Float = 0.2
+
+ let playerStateCount = 32 // 21 + 10 + 1 offset
+ let dealerVisibleStateCount = 11 // 10 + 1 offset
+ let aceStateCount = 2 // useable / not bool
+ let playerActionCount = 2 // hit / stay
+
+ init() {
+ Q = Array(repeating: Array(repeating: Array(repeating: Array(repeating: 0.0,
+ count: playerActionCount),
+ count: aceStateCount),
+ count: dealerVisibleStateCount),
+ count: playerStateCount)
+ }
+
+ func updateQLearningStrategy(prior: BlackjackState,
+ action: Int,
+ reward: Int,
+ post: BlackjackState) {
+ let oldQ = Q[prior.playerSum][prior.dealerCard][prior.useableAce][action]
+ let priorQ = (1 - alpha) * oldQ
+
+ let maxReward = max(Q[post.playerSum][post.dealerCard][post.useableAce][0],
+ Q[post.playerSum][post.dealerCard][post.useableAce][1])
+ let postQ = alpha * (Float(reward) + gamma * maxReward)
+
+ Q[prior.playerSum][prior.dealerCard][prior.useableAce][action] += priorQ + postQ
+ }
+
+ func qLearningStrategy(observation: BlackjackState, iteration: Int) -> Strategy {
+ let qLookup = Q[observation.playerSum][observation.dealerCard][observation.useableAce]
+ let stayReward = qLookup[0]
+ let hitReward = qLookup[1]
+
+ if iteration < Int.random(in: 1...learningPhase) {
+ return randomStrategy()
+ } else {
+ // quit learning after initial phase
+ if iteration > learningPhase { alpha = 0.0 }
+ }
+
+ if hitReward == stayReward {
+ return randomStrategy()
+ } else {
+ return hitReward > stayReward
+ }
+ }
+
+ func randomStrategy() -> Strategy {
+ return Strategy.random()
+ }
+
+ func markovStrategy(observation: BlackjackState) -> Strategy {
+ // hit @ 80% probability unless over 18, in which case do the reverse
+ let flip = Float.random(in: 0..<1)
+ let threshHold: Float = 0.8
+
+ if observation.playerSum < 18 {
+ return flip < threshHold
+ } else {
+ return flip > threshHold
+ }
+ }
+
+ func normalStrategyLookup(playerSum: Int) -> String {
+ // see figure 11: https://ieeexplore.ieee.org/document/1299399/
+ switch playerSum {
+ case 10: return "HHHHHSSHHH"
+ case 11: return "HHSSSSSSHH"
+ case 12: return "HSHHHHHHHH"
+ case 13: return "HSSHHHHHHH"
+ case 14: return "HSHHHHHHHH"
+ case 15: return "HSSHHHHHHH"
+ case 16: return "HSSSSSHHHH"
+ case 17: return "HSSSSHHHHH"
+ case 18: return "SSSSSSSSSS"
+ case 19: return "SSSSSSSSSS"
+ case 20: return "SSSSSSSSSS"
+ case 21: return "SSSSSSSSSS"
+ default: return "HHHHHHHHHH"
+ }
+ }
+
+ func normalStrategy(observation: BlackjackState) -> Strategy {
+ if observation.playerSum == 0 {
+ return true
+ }
+ let lookupString = normalStrategyLookup(playerSum: observation.playerSum)
+ return Array(lookupString)[observation.dealerCard - 1] == "H"
+ }
+
+ func strategy(observation: BlackjackState, solver: SolverType, iteration: Int) -> Strategy {
+ switch solver {
+ case .random:
+ return randomStrategy()
+ case .markov:
+ return markovStrategy(observation: observation)
+ case .qlearning:
+ return qLearningStrategy(observation: observation, iteration: iteration)
+ case .normal:
+ return normalStrategy(observation: observation)
+ }
+ }
+}
+
+let learner = Solver()
+
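+// Play `iterationCount` hands with each solver and report the cumulative reward
+// (in Blackjack-v0, +1 for a win, -1 for a loss, and 0 for a draw).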
+for solver in SolverType.allCases {
+ var totalReward = 0
+
+ for i in 1...iterationCount {
+ var isDone = false
+ environment.reset()
+
+ while !isDone {
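+            // `_get_obs()` is Gym's internal observation helper; it returns the
+            // (playerSum, dealerCard, usableAce) tuple used to build BlackjackState.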
+ let priorState = BlackjackState(pythonState: environment._get_obs())
+ let action: Int = learner.strategy(observation: priorState,
+ solver: solver,
+ iteration: i) ? 1 : 0
+
+ let (pythonPostState, reward, done, _) = environment.step(action).tuple4
+
+ if solver == .qlearning {
+ let postState = BlackjackState(pythonState: pythonPostState)
+ learner.updateQLearningStrategy(prior: priorState,
+ action: action,
+ reward: Int(reward) ?? 0,
+ post: postState)
+ }
+
+ if done == true {
+ totalReward += Int(reward) ?? 0
+ isDone = true
+ }
+ }
+ }
+ print("Solver: \(solver), Total reward: \(totalReward) / \(iterationCount) trials")
+}
diff --git a/Gym/CartPole/main.swift b/Gym/CartPole/main.swift
index e989fe4266b..2825e1400f5 100644
--- a/Gym/CartPole/main.swift
+++ b/Gym/CartPole/main.swift
@@ -48,7 +48,7 @@ struct Net: Layer {
}
@differentiable
- func call(_ input: Input) -> Output {
+ func callAsFunction(_ input: Input) -> Output {
return input.sequenced(through: l1, l2)
}
}
@@ -184,7 +184,7 @@ while true {
return loss
}
}
- optimizer.update(&net.allDifferentiableVariables, along: gradients)
+ optimizer.update(&net, along: gradients)
print("It has episode count \(episodeCount) and mean reward \(meanReward)")
diff --git a/Gym/README.md b/Gym/README.md
index 1b88aca6b3e..97d5bfbf66b 100644
--- a/Gym/README.md
+++ b/Gym/README.md
@@ -10,6 +10,10 @@ This directory contains reinforcement learning algorithms in [OpenAI Gym](https:
> The agent controls the movement of a character in a grid world. Some tiles of the grid are walkable, and others lead to the agent falling into the water. Additionally, the movement direction of the agent is uncertain and only partially depends on the chosen direction. The agent is rewarded for finding a walkable path to a goal tile.
+## [Blackjack](https://gym.openai.com/envs/Blackjack-v0)
+
+> This example demonstrates four strategies for playing Blackjack (random, a simple probabilistic heuristic, a basic-strategy lookup table, and Q-learning) and compares the total reward each earns over 10,000 simulated hands.
+
## Setup
To begin, you'll need the [latest version of Swift for
@@ -26,4 +30,5 @@ To build and run the models, run:
```bash
swift run Gym-CartPole
swift run Gym-FrozenLake
+swift run Gym-Blackjack
```
diff --git a/MNIST/main.swift b/MNIST/main.swift
deleted file mode 100644
index 99953ada6af..00000000000
--- a/MNIST/main.swift
+++ /dev/null
@@ -1,118 +0,0 @@
-// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-import Foundation
-import TensorFlow
-
-/// Reads a file into an array of bytes.
-func readFile(_ path: String) -> [UInt8] {
- let possibleFolders = [".", "MNIST"]
- for folder in possibleFolders {
- let parent = URL(fileURLWithPath: folder)
- let filePath = parent.appendingPathComponent(path)
- guard FileManager.default.fileExists(atPath: filePath.path) else {
- continue
- }
- let data = try! Data(contentsOf: filePath, options: [])
- return [UInt8](data)
- }
- print("File not found: \(path)")
- exit(-1)
-}
-
-/// Reads MNIST images and labels from specified file paths.
-func readMNIST(imagesFile: String, labelsFile: String) -> (images: Tensor,
- labels: Tensor) {
- print("Reading data.")
- let images = readFile(imagesFile).dropFirst(16).map(Float.init)
- let labels = readFile(labelsFile).dropFirst(8).map(Int32.init)
- let rowCount = labels.count
- let imageHeight = 28, imageWidth = 28
-
- print("Constructing data tensors.")
- return (
- images: Tensor(shape: [rowCount, 1, imageHeight, imageWidth], scalars: images)
- .transposed(withPermutations: [0, 2, 3, 1]) / 255, // NHWC
- labels: Tensor(labels)
- )
-}
-
-/// A classifier.
-struct Classifier: Layer {
- typealias Input = Tensor
- typealias Output = Tensor
-
- var conv1a = Conv2D(filterShape: (3, 3, 1, 32), activation: relu)
- var conv1b = Conv2D(filterShape: (3, 3, 32, 64), activation: relu)
- var pool1 = MaxPool2D(poolSize: (2, 2), strides: (2, 2))
-
- var dropout1a = Dropout(probability: 0.25)
- var flatten = Flatten()
- var layer1a = Dense(inputSize: 9216, outputSize: 128, activation: relu)
- var dropout1b = Dropout(probability: 0.5)
- var layer1b = Dense(inputSize: 128, outputSize: 10, activation: softmax)
-
- @differentiable
- func call(_ input: Input) -> Output {
- let convolved = input.sequenced(through: conv1a, conv1b, pool1)
- return convolved.sequenced(through: dropout1a, flatten, layer1a, dropout1b, layer1b)
- }
-}
-
-let epochCount = 12
-let batchSize = 100
-
-func minibatch<Scalar>(in x: Tensor<Scalar>, at index: Int) -> Tensor<Scalar> {
-    let start = index * batchSize
-    return x[start..<start + batchSize]
-}
-
-let (images, numericLabels) = readMNIST(imagesFile: "train-images-idx3-ubyte",
-                                        labelsFile: "train-labels-idx1-ubyte")
-let labels = Tensor<Float>(oneHotAtIndices: numericLabels, depth: 10)
-
-var classifier = Classifier()
-let optimizer = RMSProp(for: classifier)
-
-print("Beginning training...")
-
-// The training loop.
-for epoch in 1...epochCount {
- var correctGuessCount = 0
- var totalGuessCount = 0
- var totalLoss: Float = 0
- for i in 0 ..< Int(labels.shape[0]) / batchSize {
- let x = minibatch(in: images, at: i)
- let y = minibatch(in: numericLabels, at: i)
- // Compute the gradient with respect to the model.
- let 𝛁model = classifier.gradient { classifier -> Tensor in
- let ŷ = classifier(x)
- let correctPredictions = ŷ.argmax(squeezingAxis: 1) .== y
- correctGuessCount += Int(Tensor(correctPredictions).sum().scalarized())
- totalGuessCount += batchSize
- let loss = softmaxCrossEntropy(logits: ŷ, labels: y)
- totalLoss += loss.scalarized()
- return loss
- }
- // Update the model's differentiable variables along the gradient vector.
- optimizer.update(&classifier.allDifferentiableVariables, along: 𝛁model)
- }
-
- let accuracy = Float(correctGuessCount) / Float(totalGuessCount)
- print("""
- [Epoch \(epoch)] \
- Loss: \(totalLoss), \
- Accuracy: \(correctGuessCount)/\(totalGuessCount) (\(accuracy))
- """)
-}
diff --git a/MNIST/train-images-idx3-ubyte b/MNIST/train-images-idx3-ubyte
deleted file mode 100644
index bbce27659e0..00000000000
Binary files a/MNIST/train-images-idx3-ubyte and /dev/null differ
diff --git a/MNIST/train-labels-idx1-ubyte b/MNIST/train-labels-idx1-ubyte
deleted file mode 100644
index d6b4c5db3b5..00000000000
Binary files a/MNIST/train-labels-idx1-ubyte and /dev/null differ
diff --git a/MiniGo/Models/GoModel.swift b/MiniGo/Models/GoModel.swift
index eb6e0450294..bdd61fb99a2 100644
--- a/MiniGo/Models/GoModel.swift
+++ b/MiniGo/Models/GoModel.swift
@@ -61,7 +61,7 @@ struct ConvBN: Layer {
}
@differentiable
- func call(_ input: Tensor) -> Tensor {
+ func callAsFunction(_ input: Tensor) -> Tensor {
return norm(conv(input))
}
}
@@ -90,7 +90,7 @@ struct ResidualIdentityBlock: Layer {
}
@differentiable
- func call(_ input: Tensor) -> Tensor {
+ func callAsFunction(_ input: Tensor) -> Tensor {
var tmp = relu(layer1(input))
tmp = layer2(tmp)
return relu(tmp + input)
@@ -106,9 +106,9 @@ extension ResidualIdentityBlock: LoadableFromPythonCheckpoint {
// This is needed because we can't conform tuples to protocols
public struct GoModelOutput: Differentiable {
- public let policy: Tensor
- public let value: Tensor
- public let logits: Tensor
+ public var policy: Tensor
+ public var value: Tensor
+ public var logits: Tensor
}
public struct GoModel: Layer {
@@ -158,7 +158,7 @@ public struct GoModel: Layer {
}
@differentiable(wrt: (self, input), vjp: _vjpCall)
- public func call(_ input: Tensor) -> GoModelOutput {
+ public func callAsFunction(_ input: Tensor) -> GoModelOutput {
let batchSize = input.shape[0]
var output = relu(initialConv(input))
@@ -183,12 +183,12 @@ public struct GoModel: Layer {
@usableFromInline
func _vjpCall(_ input: Tensor)
- -> (GoModelOutput, (GoModelOutput.CotangentVector)
- -> (GoModel.CotangentVector, Tensor)) {
+ -> (GoModelOutput, (GoModelOutput.TangentVector)
+ -> (GoModel.TangentVector, Tensor)) {
// TODO(jekbradbury): add a real VJP
// (we're only interested in inference for now and have control flow in our `call(_:)` method)
return (self(input), {
- seed in (GoModel.CotangentVector.zero, Tensor(0))
+ seed in (GoModel.TangentVector.zero, Tensor(0))
})
}
}
diff --git a/MiniGo/Models/PythonCheckpointReader.swift b/MiniGo/Models/PythonCheckpointReader.swift
index ed434e6563a..2fe1a950032 100644
--- a/MiniGo/Models/PythonCheckpointReader.swift
+++ b/MiniGo/Models/PythonCheckpointReader.swift
@@ -28,12 +28,9 @@ public class PythonCheckpointReader {
let countSuffix = layerCounts[layerName] == nil ? "" : "_\(layerCounts[layerName]!)"
let tensorName = layerName + countSuffix + "/" + weightName
// TODO(jekbradbury): support variadic dtype attrs in RawOpsGenerated
- return Tensor(handle: #tfop(
- "RestoreV2",
- StringTensor(path),
- StringTensor([tensorName]),
- StringTensor([""]),
- dtypes$dtype: [Float.tensorFlowDataType]))
+ return Raw.restoreV2(prefix: StringTensor(path),
+ tensorNames: StringTensor([tensorName]),
+ shapeAndSlices: StringTensor([""]))
}
/// Increments a per-layer counter for variable names in the checkpoint file.
diff --git a/MiniGo/README.md b/MiniGo/README.md
index 9ae8938a809..0126a223c71 100644
--- a/MiniGo/README.md
+++ b/MiniGo/README.md
@@ -40,7 +40,7 @@ gsutil cp 'gs://minigo-pub/v15-19x19/models/000939-heron.*' MiniGoCheckpoint/
```sh
# Run inference (self-plays).
cd swift-models
-swift run -Xlinker -ltensorflow -c release MiniGo
+swift run -Xlinker -ltensorflow -c release MiniGoDemo
```
[Swift for TensorFlow]: https://www.tensorflow.org/swift
diff --git a/Models/ImageClassification/LeNet-5.swift b/Models/ImageClassification/LeNet-5.swift
new file mode 100644
index 00000000000..fc8dba008c8
--- /dev/null
+++ b/Models/ImageClassification/LeNet-5.swift
@@ -0,0 +1,42 @@
+// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+import TensorFlow
+
+// Original Paper:
+// "Gradient-Based Learning Applied to Document Recognition"
+// Yann LeCun, Léon Bottou, Yoshua Bengio, and Patrick Haffner
+// http://yann.lecun.com/exdb/publis/pdf/lecun-01a.pdf
+//
+// Note: this implementation connects all the feature maps in the second convolutional layer.
+// Additionally, ReLU is used instead of sigmoid activations.
+
+public struct LeNet: Layer {
+ public var conv1 = Conv2D(filterShape: (5, 5, 1, 6), padding: .same, activation: relu)
+ public var pool1 = AvgPool2D(poolSize: (2, 2), strides: (2, 2))
+ public var conv2 = Conv2D(filterShape: (5, 5, 6, 16), activation: relu)
+ public var pool2 = AvgPool2D(poolSize: (2, 2), strides: (2, 2))
+ public var flatten = Flatten()
+ public var fc1 = Dense(inputSize: 400, outputSize: 120, activation: relu)
+ public var fc2 = Dense(inputSize: 120, outputSize: 84, activation: relu)
+ public var fc3 = Dense(inputSize: 84, outputSize: 10, activation: softmax)
+
+ public init() {}
+
+ @differentiable
+ public func callAsFunction(_ input: Tensor) -> Tensor {
+ let convolved = input.sequenced(through: conv1, pool1, conv2, pool2)
+ return convolved.sequenced(through: flatten, fc1, fc2, fc3)
+ }
+}
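+
+// A minimal usage sketch (assuming `images` is an MNIST-style batch with shape
+// [batchSize, 28, 28, 1]):
+//
+//     let model = LeNet()
+//     let probabilities = model(images)  // shape [batchSize, 10], softmax output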
diff --git a/Models/ImageClassification/ResNet50.swift b/Models/ImageClassification/ResNet50.swift
new file mode 100644
index 00000000000..0e79d471334
--- /dev/null
+++ b/Models/ImageClassification/ResNet50.swift
@@ -0,0 +1,336 @@
+// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+import TensorFlow
+
+// Original Paper:
+// "Deep Residual Learning for Image Recognition"
+// Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
+// https://arxiv.org/abs/1512.03385
+// This implementation uses a projection shortcut layer to connect blocks where dimensions
+// change (option (B) in the paper).
+
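+// Selects the input stem and classifier head: `.imagenet` uses a 7x7, stride-2 stem with 3x3
+// max pooling and a 1000-way classifier; `.cifar` uses a 3x3 stem, no max pooling, and a
+// 10-way classifier.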
+public enum DataKind {
+ case cifar
+ case imagenet
+}
+
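+// A 2-D convolution followed by batch normalization, the basic unit used by every block below.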
+public struct ConvBN: Layer {
+ public var conv: Conv2D
+ public var norm: BatchNorm
+
+ public init(
+ filterShape: (Int, Int, Int, Int),
+ strides: (Int, Int) = (1, 1),
+ padding: Padding = .valid
+ ) {
+ self.conv = Conv2D(filterShape: filterShape, strides: strides, padding: padding)
+ self.norm = BatchNorm(featureCount: filterShape.3)
+ }
+
+ @differentiable
+ public func callAsFunction(_ input: Tensor) -> Tensor {
+ return input.sequenced(through: conv, norm)
+ }
+}
+
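+// A two-layer basic block that halves the spatial resolution (stride 2) and uses a 1x1
+// projection shortcut to match the new feature count.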
+public struct ResidualBasicBlockShortcut: Layer {
+ public var layer1: ConvBN
+ public var layer2: ConvBN
+ public var shortcut: ConvBN
+
+ public init(featureCounts: (Int, Int, Int, Int), kernelSize: Int = 3) {
+ self.layer1 = ConvBN(
+ filterShape: (kernelSize, kernelSize, featureCounts.0, featureCounts.1),
+ strides: (2, 2),
+ padding: .same)
+ self.layer2 = ConvBN(
+ filterShape: (kernelSize, kernelSize, featureCounts.1, featureCounts.2),
+ strides: (1, 1),
+ padding: .same)
+ self.shortcut = ConvBN(
+ filterShape: (1, 1, featureCounts.0, featureCounts.3),
+ strides: (2, 2),
+ padding: .same)
+ }
+
+ @differentiable
+ public func callAsFunction(_ input: Tensor) -> Tensor {
+ return layer2(relu(layer1(input))) + shortcut(input)
+ }
+}
+
+public struct ResidualBasicBlock: Layer {
+ public var layer1: ConvBN
+ public var layer2: ConvBN
+
+ public init(
+ featureCounts: (Int, Int, Int, Int),
+ kernelSize: Int = 3,
+ strides: (Int, Int) = (1, 1)
+ ) {
+ self.layer1 = ConvBN(
+ filterShape: (kernelSize, kernelSize, featureCounts.0, featureCounts.1),
+ strides: strides,
+ padding: .same)
+ self.layer2 = ConvBN(
+ filterShape: (kernelSize, kernelSize, featureCounts.1, featureCounts.3),
+ strides: strides,
+ padding: .same)
+ }
+
+ @differentiable
+ public func callAsFunction(_ input: Tensor) -> Tensor {
+ return layer2(relu(layer1(input)))
+ }
+}
+
+public struct ResidualBasicBlockStack: Layer {
+ public var blocks: [ResidualBasicBlock] = []
+
+ public init(featureCounts: (Int, Int, Int, Int), kernelSize: Int = 3, blockCount: Int) {
+        for _ in 0..<blockCount {
+            blocks.append(ResidualBasicBlock(featureCounts: featureCounts, kernelSize: kernelSize))
+        }
+    }
+
+    @differentiable
+    public func callAsFunction(_ input: Tensor<Float>) -> Tensor<Float> {
+ let blocksReduced = blocks.differentiableReduce(input) { last, layer in
+ layer(last)
+ }
+ return blocksReduced
+ }
+}
+
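+// The 1x1 -> 3x3 -> 1x1 bottleneck block with a projection shortcut ("conv block"), used at the
+// start of each stage where the feature count or resolution changes.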
+public struct ResidualConvBlock: Layer {
+ public var layer1: ConvBN
+ public var layer2: ConvBN
+ public var layer3: ConvBN
+ public var shortcut: ConvBN
+
+ public init(
+ featureCounts: (Int, Int, Int, Int),
+ kernelSize: Int = 3,
+ strides: (Int, Int) = (2, 2)
+ ) {
+ self.layer1 = ConvBN(
+ filterShape: (1, 1, featureCounts.0, featureCounts.1),
+ strides: strides)
+ self.layer2 = ConvBN(
+ filterShape: (kernelSize, kernelSize, featureCounts.1, featureCounts.2),
+ padding: .same)
+ self.layer3 = ConvBN(filterShape: (1, 1, featureCounts.2, featureCounts.3))
+ self.shortcut = ConvBN(
+ filterShape: (1, 1, featureCounts.0, featureCounts.3),
+ strides: strides,
+ padding: .same)
+ }
+
+ @differentiable
+ public func callAsFunction(_ input: Tensor) -> Tensor {
+ let tmp = relu(layer2(relu(layer1(input))))
+ return relu(layer3(tmp) + shortcut(input))
+ }
+}
+
+public struct ResidualIdentityBlock: Layer {
+ public var layer1: ConvBN
+ public var layer2: ConvBN
+ public var layer3: ConvBN
+
+ public init(featureCounts: (Int, Int, Int, Int), kernelSize: Int = 3) {
+ self.layer1 = ConvBN(filterShape: (1, 1, featureCounts.0, featureCounts.1))
+ self.layer2 = ConvBN(
+ filterShape: (kernelSize, kernelSize, featureCounts.1, featureCounts.2),
+ padding: .same)
+ self.layer3 = ConvBN(filterShape: (1, 1, featureCounts.2, featureCounts.3))
+ }
+
+ @differentiable
+ public func callAsFunction(_ input: Tensor) -> Tensor {
+ let tmp = relu(layer2(relu(layer1(input))))
+ return relu(layer3(tmp) + input)
+ }
+}
+
+public struct ResidualIdentityBlockStack: Layer {
+ public var blocks: [ResidualIdentityBlock] = []
+
+ public init(featureCounts: (Int, Int, Int, Int), kernelSize: Int = 3, blockCount: Int) {
+        for _ in 0..<blockCount {
+            blocks.append(ResidualIdentityBlock(featureCounts: featureCounts, kernelSize: kernelSize))
+        }
+    }
+
+    @differentiable
+    public func callAsFunction(_ input: Tensor<Float>) -> Tensor<Float> {
+ let blocksReduced = blocks.differentiableReduce(input) { last, layer in
+ layer(last)
+ }
+ return blocksReduced
+ }
+}
+
+public struct ResNetBasic: Layer {
+ public var l1: ConvBN
+ public var maxPool: MaxPool2D
+
+ public var l2a = ResidualBasicBlock(featureCounts: (64, 64, 64, 64))
+ public var l2b: ResidualBasicBlockStack
+
+ public var l3a = ResidualBasicBlockShortcut(featureCounts: (64, 128, 128, 128))
+ public var l3b: ResidualBasicBlockStack
+
+ public var l4a = ResidualBasicBlockShortcut(featureCounts: (128, 256, 256, 256))
+ public var l4b: ResidualBasicBlockStack
+
+ public var l5a = ResidualBasicBlockShortcut(featureCounts: (256, 512, 512, 512))
+ public var l5b: ResidualBasicBlockStack
+
+ public var avgPool: AvgPool2D
+ public var flatten = Flatten()
+ public var classifier: Dense
+
+ public init(dataKind: DataKind, layerBlockCounts: (Int, Int, Int, Int)) {
+ switch dataKind {
+ case .imagenet:
+ l1 = ConvBN(filterShape: (7, 7, 3, 64), strides: (2, 2), padding: .same)
+ maxPool = MaxPool2D(poolSize: (3, 3), strides: (2, 2))
+ avgPool = AvgPool2D(poolSize: (7, 7), strides: (7, 7))
+ classifier = Dense(inputSize: 512, outputSize: 1000)
+ case .cifar:
+ l1 = ConvBN(filterShape: (3, 3, 3, 64), padding: .same)
+ maxPool = MaxPool2D(poolSize: (1, 1), strides: (1, 1)) // no-op
+ avgPool = AvgPool2D(poolSize: (4, 4), strides: (4, 4))
+ classifier = Dense(inputSize: 512, outputSize: 10)
+ }
+
+ l2b = ResidualBasicBlockStack(
+ featureCounts: (64, 64, 64, 64),
+ blockCount: layerBlockCounts.0)
+ l3b = ResidualBasicBlockStack(
+ featureCounts: (128, 128, 128, 128),
+ blockCount: layerBlockCounts.1)
+ l4b = ResidualBasicBlockStack(
+ featureCounts: (256, 256, 256, 256),
+ blockCount: layerBlockCounts.2)
+ l5b = ResidualBasicBlockStack(
+ featureCounts: (512, 512, 512, 512),
+ blockCount: layerBlockCounts.3)
+ }
+
+ @differentiable
+ public func callAsFunction(_ input: Tensor) -> Tensor {
+ let inputLayer = maxPool(relu(l1(input)))
+ let level2 = inputLayer.sequenced(through: l2a, l2b)
+ let level3 = level2.sequenced(through: l3a, l3b)
+ let level4 = level3.sequenced(through: l4a, l4b)
+ let level5 = level4.sequenced(through: l5a, l5b)
+ return level5.sequenced(through: avgPool, flatten, classifier)
+ }
+}
+
+extension ResNetBasic {
+ public enum Kind {
+ case resNet18
+ case resNet34
+ }
+
+ public init(inputKind: Kind, dataKind: DataKind) {
+ switch inputKind {
+ case .resNet18:
+ self.init(dataKind: dataKind, layerBlockCounts: (2, 2, 2, 2))
+ case .resNet34:
+ self.init(dataKind: dataKind, layerBlockCounts: (3, 4, 6, 3))
+ }
+ }
+}
+
+public struct ResNet: Layer {
+ public var l1: ConvBN
+ public var maxPool: MaxPool2D
+
+ public var l2a = ResidualConvBlock(featureCounts: (64, 64, 64, 256), strides: (1, 1))
+ public var l2b: ResidualIdentityBlockStack
+
+ public var l3a = ResidualConvBlock(featureCounts: (256, 128, 128, 512))
+ public var l3b: ResidualIdentityBlockStack
+
+ public var l4a = ResidualConvBlock(featureCounts: (512, 256, 256, 1024))
+ public var l4b: ResidualIdentityBlockStack
+
+ public var l5a = ResidualConvBlock(featureCounts: (1024, 512, 512, 2048))
+ public var l5b: ResidualIdentityBlockStack
+
+ public var avgPool: AvgPool2D
+ public var flatten = Flatten()
+ public var classifier: Dense
+
+ public init(dataKind: DataKind, layerBlockCounts: (Int, Int, Int, Int)) {
+ switch dataKind {
+ case .imagenet:
+ l1 = ConvBN(filterShape: (7, 7, 3, 64), strides: (2, 2), padding: .same)
+ maxPool = MaxPool2D(poolSize: (3, 3), strides: (2, 2))
+ avgPool = AvgPool2D(poolSize: (7, 7), strides: (7, 7))
+ classifier = Dense(inputSize: 2048, outputSize: 1000)
+ case .cifar:
+ l1 = ConvBN(filterShape: (3, 3, 3, 64), padding: .same)
+ maxPool = MaxPool2D(poolSize: (1, 1), strides: (1, 1)) // no-op
+ avgPool = AvgPool2D(poolSize: (4, 4), strides: (4, 4))
+ classifier = Dense(inputSize: 2048, outputSize: 10)
+ }
+
+ l2b = ResidualIdentityBlockStack(
+ featureCounts: (256, 64, 64, 256),
+ blockCount: layerBlockCounts.0)
+ l3b = ResidualIdentityBlockStack(
+ featureCounts: (512, 128, 128, 512),
+ blockCount: layerBlockCounts.1)
+ l4b = ResidualIdentityBlockStack(
+ featureCounts: (1024, 256, 256, 1024),
+ blockCount: layerBlockCounts.2)
+ l5b = ResidualIdentityBlockStack(
+ featureCounts: (2048, 512, 512, 2048),
+ blockCount: layerBlockCounts.3)
+ }
+
+ @differentiable
+ public func callAsFunction(_ input: Tensor) -> Tensor {
+ let inputLayer = maxPool(relu(l1(input)))
+ let level2 = inputLayer.sequenced(through: l2a, l2b)
+ let level3 = level2.sequenced(through: l3a, l3b)
+ let level4 = level3.sequenced(through: l4a, l4b)
+ let level5 = level4.sequenced(through: l5a, l5b)
+ return level5.sequenced(through: avgPool, flatten, classifier)
+ }
+}
+
+extension ResNet {
+ public enum Kind {
+ case resNet50
+ case resNet101
+ case resNet152
+ }
+
+ public init(inputKind: Kind, dataKind: DataKind) {
+ switch inputKind {
+ case .resNet50:
+ self.init(dataKind: dataKind, layerBlockCounts: (3, 4, 6, 3))
+ case .resNet101:
+ self.init(dataKind: dataKind, layerBlockCounts: (3, 4, 23, 3))
+ case .resNet152:
+ self.init(dataKind: dataKind, layerBlockCounts: (3, 8, 36, 3))
+ }
+ }
+}
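+
+// A minimal usage sketch (the optimizer and input tensor are illustrative assumptions):
+//
+//     var model = ResNet(inputKind: .resNet50, dataKind: .imagenet)
+//     let optimizer = SGD(for: model, learningRate: 0.1, momentum: 0.9)
+//     let logits = model(images)  // `images`: [batchSize, 224, 224, 3], output: [batchSize, 1000]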
diff --git a/ResNet/ResNetV2.swift b/Models/ImageClassification/ResNetV2.swift
similarity index 56%
rename from ResNet/ResNetV2.swift
rename to Models/ImageClassification/ResNetV2.swift
index b8545e8ee48..ecd353c6cd2 100644
--- a/ResNet/ResNetV2.swift
+++ b/Models/ImageClassification/ResNetV2.swift
@@ -21,14 +21,11 @@ import TensorFlow
// https://arxiv.org/abs/1603.05027
// https://github.com/KaimingHe/resnet-1k-layers/
-struct Conv2DBatchNorm: Layer {
- typealias Input = Tensor
- typealias Output = Tensor
+public struct Conv2DBatchNorm: Layer {
+ public var conv: Conv2D
+ public var norm: BatchNorm
- var conv: Conv2D
- var norm: BatchNorm
-
- init(
+ public init(
filterShape: (Int, Int, Int, Int),
strides: (Int, Int) = (1, 1),
padding: Padding = .valid
@@ -38,19 +35,16 @@ struct Conv2DBatchNorm: Layer {
}
@differentiable
- func call(_ input: Input) -> Output {
+ public func callAsFunction(_ input: Tensor) -> Tensor {
return input.sequenced(through: conv, norm)
}
}
-struct BatchNormConv2D: Layer {
- typealias Input = Tensor
- typealias Output = Tensor
-
- var norm: BatchNorm
- var conv: Conv2D
+public struct BatchNormConv2D: Layer {
+ public var norm: BatchNorm
+ public var conv: Conv2D
- init(
+ public init(
filterShape: (Int, Int, Int, Int),
strides: (Int, Int) = (1, 1),
padding: Padding = .valid
@@ -60,19 +54,16 @@ struct BatchNormConv2D: Layer {
}
@differentiable
- func call(_ input: Input) -> Output {
+ public func callAsFunction(_ input: Tensor) -> Tensor {
return conv(relu(norm(input)))
}
}
-struct PreActivatedResidualBasicBlock: Layer {
- typealias Input = Tensor
- typealias Output = Tensor
-
- var layer1: BatchNormConv2D
- var layer2: BatchNormConv2D
+public struct PreActivatedResidualBasicBlock: Layer {
+ public var layer1: BatchNormConv2D
+ public var layer2: BatchNormConv2D
- init(
+ public init(
featureCounts: (Int, Int, Int, Int),
kernelSize: Int = 3,
strides: (Int, Int) = (1, 1)
@@ -88,20 +79,17 @@ struct PreActivatedResidualBasicBlock: Layer {
}
@differentiable
- func call(_ input: Input) -> Output {
+ public func callAsFunction(_ input: Tensor) -> Tensor {
return input.sequenced(through: layer1, layer2)
}
}
-struct PreActivatedResidualBasicBlockShortcut: Layer {
- typealias Input = Tensor
- typealias Output = Tensor
+public struct PreActivatedResidualBasicBlockShortcut: Layer {
+ public var layer1: BatchNormConv2D
+ public var layer2: BatchNormConv2D
+ public var shortcut: Conv2D
- var layer1: BatchNormConv2D
- var layer2: BatchNormConv2D
- var shortcut: Conv2D
-
- init(featureCounts: (Int, Int, Int, Int), kernelSize: Int = 3) {
+ public init(featureCounts: (Int, Int, Int, Int), kernelSize: Int = 3) {
self.layer1 = BatchNormConv2D(
filterShape: (kernelSize, kernelSize, featureCounts.0, featureCounts.1),
strides: (2, 2),
@@ -117,36 +105,33 @@ struct PreActivatedResidualBasicBlockShortcut: Layer {
}
@differentiable
- func call(_ input: Input) -> Output {
+ public func callAsFunction(_ input: Tensor) -> Tensor {
return input.sequenced(through: layer1, layer2) + shortcut(input)
}
}
-struct PreActivatedResNet18: Layer {
- typealias Input = Tensor
- typealias Output = Tensor