diff --git a/GAN/README.md b/GAN/README.md new file mode 100644 index 00000000000..d9bbe8fca8a --- /dev/null +++ b/GAN/README.md @@ -0,0 +1,23 @@ +# Simple GAN + +After Epoch 1: +

+ +

+ +After Epoch 10: +

+ +

+ +## Setup + +To begin, you'll need the [latest version of Swift for +TensorFlow](https://github.com/tensorflow/swift/blob/master/Installation.md) +installed. Make sure you've added the correct version of `swift` to your path. + +To train the model, run: + +```sh +swift run GAN +``` \ No newline at end of file diff --git a/GAN/Resources/train-images-idx3-ubyte b/GAN/Resources/train-images-idx3-ubyte new file mode 100644 index 00000000000..bbce27659e0 Binary files /dev/null and b/GAN/Resources/train-images-idx3-ubyte differ diff --git a/GAN/images/epoch-1-output.png b/GAN/images/epoch-1-output.png new file mode 100644 index 00000000000..b7fd99b7295 Binary files /dev/null and b/GAN/images/epoch-1-output.png differ diff --git a/GAN/images/epoch-10-output.png b/GAN/images/epoch-10-output.png new file mode 100644 index 00000000000..66669272db4 Binary files /dev/null and b/GAN/images/epoch-10-output.png differ diff --git a/GAN/main.swift b/GAN/main.swift new file mode 100644 index 00000000000..7389375c674 --- /dev/null +++ b/GAN/main.swift @@ -0,0 +1,210 @@ +// Copyright 2019 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import Foundation +import TensorFlow +import Python + +// Import Python modules. +let matplotlib = Python.import("matplotlib") +let np = Python.import("numpy") +let plt = Python.import("matplotlib.pyplot") + +// Turn off using display on server / Linux. +matplotlib.use("Agg") + +let epochCount = 10 +let batchSize = 32 +let outputFolder = "./output/" +let imageHeight = 28, imageWidth = 28 +let imageSize = imageHeight * imageWidth +let latentSize = 64 + +func plotImage(_ image: Tensor, name: String) { + // Create figure. + let ax = plt.gca() + let array = np.array([image.scalars]) + let pixels = array.reshape(image.shape) + if !FileManager.default.fileExists(atPath: outputFolder) { + try! FileManager.default.createDirectory( + atPath: outputFolder, + withIntermediateDirectories: false, + attributes: nil) + } + ax.imshow(pixels, cmap: "gray") + plt.savefig("\(outputFolder)\(name).png", dpi: 300) + plt.close() +} + +/// Reads a file into an array of bytes. +func readFile(_ filename: String) -> [UInt8] { + let possibleFolders = [".", "Resources", "GAN/Resources"] + for folder in possibleFolders { + let parent = URL(fileURLWithPath: folder) + let filePath = parent.appendingPathComponent(filename).path + guard FileManager.default.fileExists(atPath: filePath) else { + continue + } + let d = Python.open(filePath, "rb").read() + return Array(numpy: np.frombuffer(d, dtype: np.uint8))! + } + print("Failed to find file with name \(filename) in the following folders: \(possibleFolders).") + exit(-1) +} + +/// Reads MNIST images from specified file path. +func readMNIST(imagesFile: String) -> Tensor { + print("Reading data.") + let images = readFile(imagesFile).dropFirst(16).map { Float($0) } + let rowCount = images.count / imageSize + print("Constructing data tensors.") + return Tensor(shape: [rowCount, imageHeight * imageWidth], scalars: images) / 255.0 * 2 - 1 +} + +// Models + +struct Generator: Layer { + var dense1 = Dense(inputSize: latentSize, outputSize: latentSize * 2, + activation: { leakyRelu($0) }) + var dense2 = Dense(inputSize: latentSize * 2, outputSize: latentSize * 4, + activation: { leakyRelu($0) }) + var dense3 = Dense(inputSize: latentSize * 4, outputSize: latentSize * 8, + activation: { leakyRelu($0) }) + var dense4 = Dense(inputSize: latentSize * 8, outputSize: imageSize, + activation: tanh) + + var batchnorm1 = BatchNorm(featureCount: latentSize * 2) + var batchnorm2 = BatchNorm(featureCount: latentSize * 4) + var batchnorm3 = BatchNorm(featureCount: latentSize * 8) + + @differentiable + func callAsFunction(_ input: Tensor) -> Tensor { + let x1 = batchnorm1(dense1(input)) + let x2 = batchnorm2(dense2(x1)) + let x3 = batchnorm3(dense3(x2)) + return dense4(x3) + } +} + +struct Discriminator: Layer { + var dense1 = Dense(inputSize: imageSize, outputSize: 256, + activation: { leakyRelu($0) }) + var dense2 = Dense(inputSize: 256, outputSize: 64, + activation: { leakyRelu($0) }) + var dense3 = Dense(inputSize: 64, outputSize: 16, + activation: { leakyRelu($0) }) + var dense4 = Dense(inputSize: 16, outputSize: 1, + activation: identity) + + @differentiable + func callAsFunction(_ input: Tensor) -> Tensor { + input.sequenced(through: dense1, dense2, dense3, dense4) + } +} + +// Loss functions + +@differentiable +func generatorLoss(fakeLogits: Tensor) -> Tensor { + sigmoidCrossEntropy(logits: fakeLogits, + labels: Tensor(ones: fakeLogits.shape)) +} + +@differentiable +func discriminatorLoss(realLogits: Tensor, fakeLogits: Tensor) -> Tensor { + let realLoss = sigmoidCrossEntropy(logits: realLogits, + labels: Tensor(ones: realLogits.shape)) + let fakeLoss = sigmoidCrossEntropy(logits: fakeLogits, + labels: Tensor(zeros: fakeLogits.shape)) + return realLoss + fakeLoss +} + +/// Returns `size` samples of noise vector. +func sampleVector(size: Int) -> Tensor { + Tensor(randomNormal: [size, latentSize]) +} + +// MNIST data logic + +func minibatch(in x: Tensor, at index: Int) -> Tensor { + let start = index * batchSize + return x[start.., name: String) { + var gridImage = testImage.reshaped(to: [testImageGridSize, testImageGridSize, + imageHeight, imageWidth]) + // Add padding. + gridImage = gridImage.padded(forSizes: [(0, 0), (0, 0), (1, 1), (1, 1)], with: 1) + // Transpose to create single image. + gridImage = gridImage.transposed(withPermutations: [0, 2, 1, 3]) + gridImage = gridImage.reshaped(to: [(imageHeight + 2) * testImageGridSize, + (imageWidth + 2) * testImageGridSize]) + // Convert [-1, 1] range to [0, 1] range. + gridImage = (gridImage + 1) / 2 + plotImage(gridImage, name: name) +} + +print("Start training...") + +// Start training loop. +for epoch in 1...epochCount { + // Start training phase. + Context.local.learningPhase = .training + for i in 0 ..< Int(images.shape[0]) / batchSize { + // Perform alternative update. + // Update generator. + let vec1 = sampleVector(size: batchSize) + + let 𝛁generator = generator.gradient { generator -> Tensor in + let fakeImages = generator(vec1) + let fakeLogits = discriminator(fakeImages) + let loss = generatorLoss(fakeLogits: fakeLogits) + return loss + } + optG.update(&generator.allDifferentiableVariables, along: 𝛁generator) + + // Update discriminator. + let realImages = minibatch(in: images, at: i) + let vec2 = sampleVector(size: batchSize) + let fakeImages = generator(vec2) + + let 𝛁discriminator = discriminator.gradient { discriminator -> Tensor in + let realLogits = discriminator(realImages) + let fakeLogits = discriminator(fakeImages) + let loss = discriminatorLoss(realLogits: realLogits, fakeLogits: fakeLogits) + return loss + } + optD.update(&discriminator.allDifferentiableVariables, along: 𝛁discriminator) + } + + // Start inference phase. + Context.local.learningPhase = .inference + let testImage = generator(testVector) + plotTestImage(testImage, name: "epoch-\(epoch)-output") + + let lossG = generatorLoss(fakeLogits: testImage) + print("[Epoch: \(epoch)] Loss-G: \(lossG)") +} diff --git a/Package.swift b/Package.swift index 289c1086abc..2f13a6b394d 100644 --- a/Package.swift +++ b/Package.swift @@ -14,6 +14,7 @@ let package = Package( .executable(name: "ResNet", targets: ["ResNet"]), .executable(name: "MiniGoDemo", targets: ["MiniGoDemo"]), .library(name: "MiniGo", targets: ["MiniGo"]), + .executable(name: "GAN", targets: ["GAN"]), ], targets: [ .target(name: "Autoencoder", path: "Autoencoder"), @@ -29,5 +30,6 @@ let package = Package( .testTarget(name: "MiniGoTests", dependencies: ["MiniGo"]), .target(name: "ResNet", path: "ResNet"), .target(name: "Transformer", path: "Transformer"), + .target(name: "GAN", path: "GAN"), ] )