diff --git a/GAN/README.md b/GAN/README.md
new file mode 100644
index 00000000000..d9bbe8fca8a
--- /dev/null
+++ b/GAN/README.md
@@ -0,0 +1,23 @@
+# Simple GAN
+
+After Epoch 1:
+
+
+
+
+After Epoch 10:
+
+
+
+
+## Setup
+
+To begin, you'll need the [latest version of Swift for
+TensorFlow](https://github.com/tensorflow/swift/blob/master/Installation.md)
+installed. Make sure you've added the correct version of `swift` to your path.
+
+To train the model, run:
+
+```sh
+swift run GAN
+```
\ No newline at end of file
diff --git a/GAN/Resources/train-images-idx3-ubyte b/GAN/Resources/train-images-idx3-ubyte
new file mode 100644
index 00000000000..bbce27659e0
Binary files /dev/null and b/GAN/Resources/train-images-idx3-ubyte differ
diff --git a/GAN/images/epoch-1-output.png b/GAN/images/epoch-1-output.png
new file mode 100644
index 00000000000..b7fd99b7295
Binary files /dev/null and b/GAN/images/epoch-1-output.png differ
diff --git a/GAN/images/epoch-10-output.png b/GAN/images/epoch-10-output.png
new file mode 100644
index 00000000000..66669272db4
Binary files /dev/null and b/GAN/images/epoch-10-output.png differ
diff --git a/GAN/main.swift b/GAN/main.swift
new file mode 100644
index 00000000000..7389375c674
--- /dev/null
+++ b/GAN/main.swift
@@ -0,0 +1,210 @@
+// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+import Foundation
+import TensorFlow
+import Python
+
+// Import Python modules.
+let matplotlib = Python.import("matplotlib")
+let np = Python.import("numpy")
+let plt = Python.import("matplotlib.pyplot")
+
+// Turn off using display on server / Linux.
+matplotlib.use("Agg")
+
+let epochCount = 10
+let batchSize = 32
+let outputFolder = "./output/"
+let imageHeight = 28, imageWidth = 28
+let imageSize = imageHeight * imageWidth
+let latentSize = 64
+
+func plotImage(_ image: Tensor, name: String) {
+ // Create figure.
+ let ax = plt.gca()
+ let array = np.array([image.scalars])
+ let pixels = array.reshape(image.shape)
+ if !FileManager.default.fileExists(atPath: outputFolder) {
+ try! FileManager.default.createDirectory(
+ atPath: outputFolder,
+ withIntermediateDirectories: false,
+ attributes: nil)
+ }
+ ax.imshow(pixels, cmap: "gray")
+ plt.savefig("\(outputFolder)\(name).png", dpi: 300)
+ plt.close()
+}
+
+/// Reads a file into an array of bytes.
+func readFile(_ filename: String) -> [UInt8] {
+ let possibleFolders = [".", "Resources", "GAN/Resources"]
+ for folder in possibleFolders {
+ let parent = URL(fileURLWithPath: folder)
+ let filePath = parent.appendingPathComponent(filename).path
+ guard FileManager.default.fileExists(atPath: filePath) else {
+ continue
+ }
+ let d = Python.open(filePath, "rb").read()
+ return Array(numpy: np.frombuffer(d, dtype: np.uint8))!
+ }
+ print("Failed to find file with name \(filename) in the following folders: \(possibleFolders).")
+ exit(-1)
+}
+
+/// Reads MNIST images from specified file path.
+func readMNIST(imagesFile: String) -> Tensor {
+ print("Reading data.")
+ let images = readFile(imagesFile).dropFirst(16).map { Float($0) }
+ let rowCount = images.count / imageSize
+ print("Constructing data tensors.")
+ return Tensor(shape: [rowCount, imageHeight * imageWidth], scalars: images) / 255.0 * 2 - 1
+}
+
+// Models
+
+struct Generator: Layer {
+ var dense1 = Dense(inputSize: latentSize, outputSize: latentSize * 2,
+ activation: { leakyRelu($0) })
+ var dense2 = Dense(inputSize: latentSize * 2, outputSize: latentSize * 4,
+ activation: { leakyRelu($0) })
+ var dense3 = Dense(inputSize: latentSize * 4, outputSize: latentSize * 8,
+ activation: { leakyRelu($0) })
+ var dense4 = Dense(inputSize: latentSize * 8, outputSize: imageSize,
+ activation: tanh)
+
+ var batchnorm1 = BatchNorm(featureCount: latentSize * 2)
+ var batchnorm2 = BatchNorm(featureCount: latentSize * 4)
+ var batchnorm3 = BatchNorm(featureCount: latentSize * 8)
+
+ @differentiable
+ func callAsFunction(_ input: Tensor) -> Tensor {
+ let x1 = batchnorm1(dense1(input))
+ let x2 = batchnorm2(dense2(x1))
+ let x3 = batchnorm3(dense3(x2))
+ return dense4(x3)
+ }
+}
+
+struct Discriminator: Layer {
+ var dense1 = Dense(inputSize: imageSize, outputSize: 256,
+ activation: { leakyRelu($0) })
+ var dense2 = Dense(inputSize: 256, outputSize: 64,
+ activation: { leakyRelu($0) })
+ var dense3 = Dense(inputSize: 64, outputSize: 16,
+ activation: { leakyRelu($0) })
+ var dense4 = Dense(inputSize: 16, outputSize: 1,
+ activation: identity)
+
+ @differentiable
+ func callAsFunction(_ input: Tensor) -> Tensor {
+ input.sequenced(through: dense1, dense2, dense3, dense4)
+ }
+}
+
+// Loss functions
+
+@differentiable
+func generatorLoss(fakeLogits: Tensor) -> Tensor {
+ sigmoidCrossEntropy(logits: fakeLogits,
+ labels: Tensor(ones: fakeLogits.shape))
+}
+
+@differentiable
+func discriminatorLoss(realLogits: Tensor, fakeLogits: Tensor) -> Tensor {
+ let realLoss = sigmoidCrossEntropy(logits: realLogits,
+ labels: Tensor(ones: realLogits.shape))
+ let fakeLoss = sigmoidCrossEntropy(logits: fakeLogits,
+ labels: Tensor(zeros: fakeLogits.shape))
+ return realLoss + fakeLoss
+}
+
+/// Returns `size` samples of noise vector.
+func sampleVector(size: Int) -> Tensor {
+ Tensor(randomNormal: [size, latentSize])
+}
+
+// MNIST data logic
+
+func minibatch(in x: Tensor, at index: Int) -> Tensor {
+ let start = index * batchSize
+ return x[start.., name: String) {
+ var gridImage = testImage.reshaped(to: [testImageGridSize, testImageGridSize,
+ imageHeight, imageWidth])
+ // Add padding.
+ gridImage = gridImage.padded(forSizes: [(0, 0), (0, 0), (1, 1), (1, 1)], with: 1)
+ // Transpose to create single image.
+ gridImage = gridImage.transposed(withPermutations: [0, 2, 1, 3])
+ gridImage = gridImage.reshaped(to: [(imageHeight + 2) * testImageGridSize,
+ (imageWidth + 2) * testImageGridSize])
+ // Convert [-1, 1] range to [0, 1] range.
+ gridImage = (gridImage + 1) / 2
+ plotImage(gridImage, name: name)
+}
+
+print("Start training...")
+
+// Start training loop.
+for epoch in 1...epochCount {
+ // Start training phase.
+ Context.local.learningPhase = .training
+ for i in 0 ..< Int(images.shape[0]) / batchSize {
+ // Perform alternative update.
+ // Update generator.
+ let vec1 = sampleVector(size: batchSize)
+
+ let 𝛁generator = generator.gradient { generator -> Tensor in
+ let fakeImages = generator(vec1)
+ let fakeLogits = discriminator(fakeImages)
+ let loss = generatorLoss(fakeLogits: fakeLogits)
+ return loss
+ }
+ optG.update(&generator.allDifferentiableVariables, along: 𝛁generator)
+
+ // Update discriminator.
+ let realImages = minibatch(in: images, at: i)
+ let vec2 = sampleVector(size: batchSize)
+ let fakeImages = generator(vec2)
+
+ let 𝛁discriminator = discriminator.gradient { discriminator -> Tensor in
+ let realLogits = discriminator(realImages)
+ let fakeLogits = discriminator(fakeImages)
+ let loss = discriminatorLoss(realLogits: realLogits, fakeLogits: fakeLogits)
+ return loss
+ }
+ optD.update(&discriminator.allDifferentiableVariables, along: 𝛁discriminator)
+ }
+
+ // Start inference phase.
+ Context.local.learningPhase = .inference
+ let testImage = generator(testVector)
+ plotTestImage(testImage, name: "epoch-\(epoch)-output")
+
+ let lossG = generatorLoss(fakeLogits: testImage)
+ print("[Epoch: \(epoch)] Loss-G: \(lossG)")
+}
diff --git a/Package.swift b/Package.swift
index 289c1086abc..2f13a6b394d 100644
--- a/Package.swift
+++ b/Package.swift
@@ -14,6 +14,7 @@ let package = Package(
.executable(name: "ResNet", targets: ["ResNet"]),
.executable(name: "MiniGoDemo", targets: ["MiniGoDemo"]),
.library(name: "MiniGo", targets: ["MiniGo"]),
+ .executable(name: "GAN", targets: ["GAN"]),
],
targets: [
.target(name: "Autoencoder", path: "Autoencoder"),
@@ -29,5 +30,6 @@ let package = Package(
.testTarget(name: "MiniGoTests", dependencies: ["MiniGo"]),
.target(name: "ResNet", path: "ResNet"),
.target(name: "Transformer", path: "Transformer"),
+ .target(name: "GAN", path: "GAN"),
]
)