-
Notifications
You must be signed in to change notification settings - Fork 150
Add GAN Example #181
Add GAN Example #181
Changes from 17 commits
b16005c
19ba68d
e6edc38
084510e
c70ad3d
e52b0ec
813342d
c00416f
799a3eb
6097719
b218e51
44a1904
39e997a
9a79072
83bdb72
d1ec2bb
4aac95f
e4d3540
dff62cb
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||
|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,30 @@ | ||||||||
| # Simple GAN | ||||||||
|
|
||||||||
| ### After Epoch 1 | ||||||||
| <p align="center"> | ||||||||
| <img src="images/epoch-1-output.png" height="270" width="360"> | ||||||||
| </p> | ||||||||
|
|
||||||||
| ### After Epoch 10 | ||||||||
|
||||||||
| ### After Epoch 10 | |
| After Epoch 10: | |
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| ``` | |
| ```sh |
Outdated
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
- Not sure what you mean exactly by "modules".
- I feel lines 24-30 are unnecessary since they are just describing an installation step that's required for the entire models repository.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It came from Autoencoder/README.md.
I just copied it and didn't consider much.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah ok, we should unify the documentation in all models. @BradLarson how about having a standard README template?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@rxwei - We definitely will need to rework the READMEs across the examples. Varying levels of information is provided in each, as well as the language used for describing the models. A template for examples would be much appreciated, as well as a good central listing of them. The original Autoencoder README was a little rough (and I think the Python information was confusing and missing the need for matplotlib, etc.), thus the issues with this one derived from it.
| Original file line number | Diff line number | Diff line change | ||||||
|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,211 @@ | ||||||||
| // Copyright 2019 The TensorFlow Authors. All Rights Reserved. | ||||||||
| // | ||||||||
| // Licensed under the Apache License, Version 2.0 (the "License"); | ||||||||
| // you may not use this file except in compliance with the License. | ||||||||
| // You may obtain a copy of the License at | ||||||||
| // | ||||||||
| // http://www.apache.org/licenses/LICENSE-2.0 | ||||||||
| // | ||||||||
| // Unless required by applicable law or agreed to in writing, software | ||||||||
| // distributed under the License is distributed on an "AS IS" BASIS, | ||||||||
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||||
| // See the License for the specific language governing permissions and | ||||||||
| // limitations under the License. | ||||||||
|
|
||||||||
| import Foundation | ||||||||
| import TensorFlow | ||||||||
| import Python | ||||||||
|
|
||||||||
| // Import Python modules. | ||||||||
| let matplotlib = Python.import("matplotlib") | ||||||||
| let np = Python.import("numpy") | ||||||||
| let plt = Python.import("matplotlib.pyplot") | ||||||||
|
|
||||||||
| // Turn off using display on server / Linux. | ||||||||
| matplotlib.use("Agg") | ||||||||
|
|
||||||||
| let epochCount = 10 | ||||||||
| let batchSize = 32 | ||||||||
| let outputFolder = "./output/" | ||||||||
| let imageHeight = 28, imageWidth = 28 | ||||||||
| let imageSize = imageHeight * imageWidth | ||||||||
| let latentSize = 64 | ||||||||
|
|
||||||||
| func plotImage(_ image: Tensor<Float>, name: String) { | ||||||||
| // Create figure. | ||||||||
| let ax = plt.gca() | ||||||||
| let array = np.array([image.scalars]) | ||||||||
| let pixels = array.reshape(image.shape) | ||||||||
| if !FileManager.default.fileExists(atPath: outputFolder) { | ||||||||
| try! FileManager.default.createDirectory( | ||||||||
| atPath: outputFolder, | ||||||||
| withIntermediateDirectories: false, | ||||||||
| attributes: nil) | ||||||||
| } | ||||||||
| ax.imshow(pixels, cmap: "gray") | ||||||||
| plt.savefig("\(outputFolder)\(name).png", dpi: 300) | ||||||||
| plt.close() | ||||||||
| } | ||||||||
|
|
||||||||
| /// Reads a file into an array of bytes. | ||||||||
| func readFile(_ filename: String) -> [UInt8] { | ||||||||
| let possibleFolders = [".", "Resources", "GAN/Resources"] | ||||||||
| for folder in possibleFolders { | ||||||||
| let parent = URL(fileURLWithPath: folder) | ||||||||
| let filePath = parent.appendingPathComponent(filename).path | ||||||||
| guard FileManager.default.fileExists(atPath: filePath) else { | ||||||||
| continue | ||||||||
| } | ||||||||
| let d = Python.open(filePath, "rb").read() | ||||||||
| return Array(numpy: np.frombuffer(d, dtype: np.uint8))! | ||||||||
| } | ||||||||
| print("Failed to find file with name \(filename) in the following folders: \(possibleFolders).") | ||||||||
| exit(-1) | ||||||||
| } | ||||||||
|
|
||||||||
| /// Reads MNIST images from specified file path. | ||||||||
| func readMNIST(imagesFile: String) -> Tensor<Float> { | ||||||||
| print("Reading data.") | ||||||||
| let images = readFile(imagesFile).dropFirst(16).map { Float($0) } | ||||||||
| let rowCount = images.count / imageSize | ||||||||
|
|
||||||||
|
||||||||
| print("Constructing data tensors.") | ||||||||
| return Tensor(shape: [rowCount, imageHeight * imageWidth], scalars: images) / 255.0 * 2 - 1 | ||||||||
| } | ||||||||
|
|
||||||||
| // Models | ||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add an empty line, since this is describing both
Suggested change
|
||||||||
|
|
||||||||
| struct Generator: Layer { | ||||||||
| var dense1 = Dense<Float>(inputSize: latentSize, outputSize: latentSize * 2, | ||||||||
| activation: { leakyRelu($0) }) | ||||||||
| var dense2 = Dense<Float>(inputSize: latentSize * 2, outputSize: latentSize * 4, | ||||||||
| activation: { leakyRelu($0) }) | ||||||||
| var dense3 = Dense<Float>(inputSize: latentSize * 4, outputSize: latentSize * 8, | ||||||||
| activation: { leakyRelu($0) }) | ||||||||
| var dense4 = Dense<Float>(inputSize: latentSize * 8, outputSize: imageSize, | ||||||||
| activation: tanh) | ||||||||
|
|
||||||||
| var batchnorm1 = BatchNorm<Float>(featureCount: latentSize * 2) | ||||||||
| var batchnorm2 = BatchNorm<Float>(featureCount: latentSize * 4) | ||||||||
| var batchnorm3 = BatchNorm<Float>(featureCount: latentSize * 8) | ||||||||
|
|
||||||||
| @differentiable | ||||||||
| func callAsFunction(_ input: Tensor<Float>) -> Tensor<Float> { | ||||||||
| let x1 = batchnorm1(dense1(input)) | ||||||||
| let x2 = batchnorm2(dense2(x1)) | ||||||||
| let x3 = batchnorm3(dense3(x2)) | ||||||||
| return dense4(x3) | ||||||||
| } | ||||||||
| } | ||||||||
|
|
||||||||
| struct Discriminator: Layer { | ||||||||
| var dense1 = Dense<Float>(inputSize: imageSize, outputSize: 256, | ||||||||
| activation: { leakyRelu($0) }) | ||||||||
| var dense2 = Dense<Float>(inputSize: 256, outputSize: 64, | ||||||||
| activation: { leakyRelu($0) }) | ||||||||
| var dense3 = Dense<Float>(inputSize: 64, outputSize: 16, | ||||||||
| activation: { leakyRelu($0) }) | ||||||||
| var dense4 = Dense<Float>(inputSize: 16, outputSize: 1, | ||||||||
| activation: identity) | ||||||||
|
|
||||||||
| @differentiable | ||||||||
| func callAsFunction(_ input: Tensor<Float>) -> Tensor<Float> { | ||||||||
| input.sequenced(through: dense1, dense2, dense3, dense4) | ||||||||
| } | ||||||||
| } | ||||||||
|
|
||||||||
| // Loss functions | ||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||
|
|
||||||||
| @differentiable | ||||||||
| func generatorLoss(fakeLogits: Tensor<Float>) -> Tensor<Float> { | ||||||||
| sigmoidCrossEntropy(logits: fakeLogits, | ||||||||
| labels: Tensor(ones: fakeLogits.shape)) | ||||||||
| } | ||||||||
|
|
||||||||
| @differentiable | ||||||||
| func discriminatorLoss(realLogits: Tensor<Float>, fakeLogits: Tensor<Float>) -> Tensor<Float> { | ||||||||
| let realLoss = sigmoidCrossEntropy(logits: realLogits, | ||||||||
| labels: Tensor(ones: realLogits.shape)) | ||||||||
| let fakeLoss = sigmoidCrossEntropy(logits: fakeLogits, | ||||||||
| labels: Tensor(zeros: fakeLogits.shape)) | ||||||||
| return realLoss + fakeLoss | ||||||||
| } | ||||||||
|
|
||||||||
| /// Returns `size` samples of noise vector. | ||||||||
| func sampleVector(size: Int) -> Tensor<Float> { | ||||||||
| Tensor<Float>(randomNormal: [size, latentSize]) | ||||||||
|
||||||||
| Tensor<Float>(randomNormal: [size, latentSize]) | |
| Tensor(randomNormal: [size, latentSize]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| // MNIST data logic | |
| // MNIST data logic | |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.