Skip to content
This repository was archived by the owner on Apr 23, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions GAN/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# Simple GAN

### After Epoch 1
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
### After Epoch 1
After Epoch 1:

<p align="center">
<img src="images/epoch-1-output.png" height="270" width="360">
</p>

### After Epoch 10
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
### After Epoch 10
After Epoch 10:

<p align="center">
<img src="images/epoch-10-output.png" height="270" width="360">
</p>

## Setup

To begin, you'll need the [latest version of Swift for
TensorFlow](https://github.com/tensorflow/swift/blob/master/Installation.md)
installed. Make sure you've added the correct version of `swift` to your path.

To train the model, run:

```
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
```
```sh

swift run GAN
```
If you using brew to install python2 and modules, change the path:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  • Not sure what you mean exactly by "modules".
  • I feel lines 24-30 are unnecessary since they are just describing an installation step that's required for the entire models repository.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It came from Autoencoder/README.md.
I just copied it and didn't consider much.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah ok, we should unify the documentation in all models. @BradLarson how about having a standard README template?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@rxwei - We definitely will need to rework the READMEs across the examples. Varying levels of information is provided in each, as well as the language used for describing the models. A template for examples would be much appreciated, as well as a good central listing of them. The original Autoencoder README was a little rough (and I think the Python information was confusing and missing the need for matplotlib, etc.), thus the issues with this one derived from it.

- remove brew path '/usr/local/bin'
- add TensorFlow swift Toolchain /Library/Developer/Toolchains/swift-latest/usr/bin

```
export PATH=/Library/Developer/Toolchains/swift-latest/usr/bin:/usr/bin:/bin:/usr/sbin:/sbin:"${PATH}"
```
Binary file added GAN/Resources/train-images-idx3-ubyte
Binary file not shown.
Binary file added GAN/Resources/train-labels-idx1-ubyte
Binary file not shown.
Binary file added GAN/images/epoch-1-output.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added GAN/images/epoch-10-output.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
214 changes: 214 additions & 0 deletions GAN/main.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,214 @@
// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import Foundation
import TensorFlow
import Python

// Import Python modules
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// Import Python modules
// Import Python modules.

let matplotlib = Python.import("matplotlib")
let np = Python.import("numpy")
let plt = Python.import("matplotlib.pyplot")

// Turn off using display on server / linux
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// Turn off using display on server / linux
// Turn off using display on server / Linux.

matplotlib.use("Agg")

// Some globals
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// Some globals

This comment isn't providing much useful information as the code is self-explanatorily "some globals".

let epochCount = 10
let batchSize = 32
let outputFolder = "./output/"
let imageHeight = 28, imageWidth = 28
let imageDim = imageHeight*imageWidth
let latentDim = 64

func plot(image: Tensor<Float>, name: String) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When the first argument label happens to be the object of the overall verb phrase, append it to the base name and omit the argument label.

Suggested change
func plot(image: Tensor<Float>, name: String) {
func plotImage(_ image: Tensor<Float>, name: String) {

// Create figure
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// Create figure
// Create figure.

End sentence comments with a period.

let ax = plt.gca()
let array = np.array([image.scalars])
let pixels = array.reshape(image.shape)
if !FileManager.default.fileExists(atPath: outputFolder) {
try! FileManager.default.createDirectory(atPath: outputFolder,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We aren't using Objective-C style formatting for function call arguments. Could you reformat this like the following?

Suggested change
try! FileManager.default.createDirectory(atPath: outputFolder,
try! FileManager.default.createDirectory(
atPath: outputFolder,
withIntermediateDirectories: false,
attributes: nil)

withIntermediateDirectories: false,
attributes: nil)
}
ax.imshow(pixels, cmap: "gray")
plt.savefig("\(outputFolder)\(name).png", dpi: 300)
plt.close()
}

/// Reads a file into an array of bytes.
func readFile(_ filename: String) -> [UInt8] {
let possibleFolders = [".", "Resources", "GAN/Resources"]
for folder in possibleFolders {
let parent = URL(fileURLWithPath: folder)
let filePath = parent.appendingPathComponent(filename).path
guard FileManager.default.fileExists(atPath: filePath) else {
continue
}
let d = Python.open(filePath, "rb").read()
return Array(numpy: np.frombuffer(d, dtype: np.uint8))!
}
print("Failed to find file with name \(filename) in the following folders: \(possibleFolders).")
exit(-1)
}

/// Reads MNIST images and labels from specified file paths.
func readMNIST(imagesFile: String, labelsFile: String) -> (images: Tensor<Float>,
labels: Tensor<Int32>) {
print("Reading data.")
let images = readFile(imagesFile).dropFirst(16).map { Float($0) }
let labels = readFile(labelsFile).dropFirst(8).map { Int32($0) }
let rowCount = labels.count

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove redundant empty line.

print("Constructing data tensors.")
return (
images: Tensor(shape: [rowCount, imageHeight * imageWidth], scalars: images) / 255.0 * 2 - 1,
labels: Tensor(labels)
)
}

// Models
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add an empty line, since this is describing both Generator and Discriminator below.

Suggested change
// Models
// Models

struct Generator: Layer {
var dense1 = Dense<Float>(inputSize: latentDim, outputSize: latentDim*2, activation: { leakyRelu($0) })
var dense2 = Dense<Float>(inputSize: latentDim*2, outputSize: latentDim*4, activation: { leakyRelu($0) })
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you make all lines fit within 100 columns?

var dense3 = Dense<Float>(inputSize: latentDim*4, outputSize: latentDim*8, activation: { leakyRelu($0) })
var dense4 = Dense<Float>(inputSize: latentDim*8, outputSize: imageDim, activation: tanh)

var batchnorm1 = BatchNorm<Float>(featureCount: latentDim*2)
var batchnorm2 = BatchNorm<Float>(featureCount: latentDim*4)
var batchnorm3 = BatchNorm<Float>(featureCount: latentDim*8)

@differentiable
func callAsFunction(_ input: Tensor<Float>) -> Tensor<Float> {
let x1 = batchnorm1(dense1(input))
let x2 = batchnorm2(dense2(x1))
let x3 = batchnorm3(dense3(x2))

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change

return dense4(x3)
}
}

struct Discriminator: Layer {
var dense1 = Dense<Float>(inputSize: imageDim, outputSize: 256, activation: { leakyRelu($0) })
var dense2 = Dense<Float>(inputSize: 256, outputSize: 64, activation: { leakyRelu($0) })
var dense3 = Dense<Float>(inputSize: 64, outputSize: 16, activation: { leakyRelu($0) })
var dense4 = Dense<Float>(inputSize: 16, outputSize: 1, activation: identity)

@differentiable
func callAsFunction(_ input: Tensor<Float>) -> Tensor<Float> {
input.sequenced(through: dense1, dense2, dense3, dense4)
}
}

// Loss functions
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// Loss functions
// Loss functions

@differentiable
func generatorLossFunc(fakeLogits: Tensor<Float>) -> Tensor<Float> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
func generatorLossFunc(fakeLogits: Tensor<Float>) -> Tensor<Float> {
func generatorLoss(fakeLogits: Tensor<Float>) -> Tensor<Float> {

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

According to Swift API Design Guidelines, functions without side effects should read like nouns.

sigmoidCrossEntropy(logits: fakeLogits,
labels: Tensor(ones: [fakeLogits.shape[0], 1]))
}

@differentiable
func discriminatorLossFunc(realLogits: Tensor<Float>, fakeLogits: Tensor<Float>) -> Tensor<Float> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
func discriminatorLossFunc(realLogits: Tensor<Float>, fakeLogits: Tensor<Float>) -> Tensor<Float> {
func discriminatorLoss(realLogits: Tensor<Float>, fakeLogits: Tensor<Float>) -> Tensor<Float> {

let realLoss = sigmoidCrossEntropy(logits: realLogits,
labels: Tensor(ones: [realLogits.shape[0], 1]))
let fakeLoss = sigmoidCrossEntropy(logits: fakeLogits,
labels: Tensor(zeros: [fakeLogits.shape[0], 1]))
return realLoss + fakeLoss
}

/// Sample noise vectors.
func sampleVector(size: Int) -> Tensor<Float> {
Tensor<Float>(randomNormal: [size, latentDim])
}

// MNIST data logic
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// MNIST data logic
// MNIST data logic

func minibatch<Scalar>(in x: Tensor<Scalar>, at index: Int) -> Tensor<Scalar> {
let start = index * batchSize
return x[start..<start+batchSize]
}

let (images, numericLabels) = readMNIST(imagesFile: "train-images-idx3-ubyte",
labelsFile: "train-labels-idx1-ubyte")
let labels = Tensor<Float>(oneHotAtIndices: numericLabels, depth: 10)

var generator = Generator()
var discriminator = Discriminator()

let optG = Adam(for: generator, learningRate: 2e-4, beta1: 0.5)
let optD = Adam(for: discriminator, learningRate: 2e-4, beta1: 0.5)

// noise for testing and plot function
let testImageGridSize = 4
let testVector = sampleVector(size: testImageGridSize*testImageGridSize)
func plotTestImage(_ testImage: Tensor<Float>, name: String) {
var imageGrid = testImage.reshaped(to: [testImageGridSize, testImageGridSize, imageHeight, imageWidth])

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove redundant empty line.

// Add padding
imageGrid = imageGrid.padded(forSizes: [(0, 0), (0, 0), (1, 1), (1, 1)], with: 1)

// Transpose to create single image.
imageGrid = imageGrid.transposed(withPermutations: [0, 2, 1, 3])
imageGrid = imageGrid.reshaped(to: [(imageHeight+2)*testImageGridSize, (imageWidth+2)*testImageGridSize])

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove redundant empty line.

// [-1, 1] range to [0, 1] range
imageGrid = (imageGrid + 1) / 2

plot(image: imageGrid, name: name)
}

print("Start training...")

// Training loop
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// Training loop
// Start training loop.

for epoch in 1...epochCount {
// Training phase
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// Training phase
// Start training phase.

Context.local.learningPhase = .training
for i in 0 ..< Int(labels.shape[0]) / batchSize {
// Alternative update
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// Alternative update
// Perform alternative update.


Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change

// Update Generator
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// Update Generator
// Update generator.

let vec1 = sampleVector(size: batchSize)

let 𝛁generator = generator.gradient { generator -> Tensor<Float> in
let fakeImages = generator(vec1)
let fakeLogits = discriminator(fakeImages)
let loss = generatorLossFunc(fakeLogits: fakeLogits)
return loss
}
optG.update(&generator.allDifferentiableVariables, along: 𝛁generator)

// Update Discriminator
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// Update Discriminator
// Update discriminator.

let realImages = minibatch(in: images, at: i)
let vec2 = sampleVector(size: batchSize)
let fakeImages = generator(vec2)

let 𝛁discriminator = discriminator.gradient { discriminator -> Tensor<Float> in
let realLogits = discriminator(realImages)
let fakeLogits = discriminator(fakeImages)
let loss = discriminatorLossFunc(realLogits: realLogits, fakeLogits: fakeLogits)
return loss
}

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove redundant empty line.

optD.update(&discriminator.allDifferentiableVariables, along: 𝛁discriminator)
}

// Inference phase
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// Inference phase
// Start inference phase.

Context.local.learningPhase = .inference
let testImage: Tensor<Float> = generator(testVector)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
let testImage: Tensor<Float> = generator(testVector)
let testImage = generator(testVector)

Remove redundant type signature


Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove redundant empty line.

plotTestImage(testImage, name: "epoch-\(epoch)-output")

let lossG = generatorLossFunc(fakeLogits: testImage)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove redundant empty line.

print("[Epoch: \(epoch)] Loss-G: \(lossG)")
}
2 changes: 2 additions & 0 deletions Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ let package = Package(
.executable(name: "ResNet", targets: ["ResNet"]),
.executable(name: "MiniGoDemo", targets: ["MiniGoDemo"]),
.library(name: "MiniGo", targets: ["MiniGo"]),
.executable(name: "GAN", targets: ["GAN"]),
],
targets: [
.target(name: "Autoencoder", path: "Autoencoder"),
Expand All @@ -29,5 +30,6 @@ let package = Package(
.testTarget(name: "MiniGoTests", dependencies: ["MiniGo"]),
.target(name: "ResNet", path: "ResNet"),
.target(name: "Transformer", path: "Transformer"),
.target(name: "GAN", path: "GAN"),
]
)