diff --git a/.gitignore b/.gitignore index 4ecfa3c6745..4fea2e7e10e 100644 --- a/.gitignore +++ b/.gitignore @@ -8,3 +8,4 @@ .swiftpm cifar-10-batches-py/ cifar-10-batches-bin/ +output/ diff --git a/Autoencoder/README.md b/Autoencoder/README.md index 2c7c25c9b61..a635a82cc93 100644 --- a/Autoencoder/README.md +++ b/Autoencoder/README.md @@ -21,8 +21,6 @@ To begin, you'll need the [latest version of Swift for TensorFlow](https://github.com/tensorflow/swift/blob/master/Installation.md) installed. Make sure you've added the correct version of `swift` to your path. -This example requires Matplotlib and NumPy to be installed, for use in image output. - To train the model, run: ``` diff --git a/Autoencoder/main.swift b/Autoencoder/main.swift index 5d4191a72d5..2f478b65770 100644 --- a/Autoencoder/main.swift +++ b/Autoencoder/main.swift @@ -12,47 +12,24 @@ // See the License for the specific language governing permissions and // limitations under the License. +import Datasets import Foundation +import ModelSupport import TensorFlow -import Python -import Datasets - -// Import Python modules -let matplotlib = Python.import("matplotlib") -let np = Python.import("numpy") - -// Use the AGG renderer for saving images to disk. -matplotlib.use("Agg") - -let plt = Python.import("matplotlib.pyplot") let epochCount = 10 let batchSize = 100 -let outputFolder = "./output/" -let imageHeight = 28, imageWidth = 28 +let imageHeight = 28 +let imageWidth = 28 -func plot(image: [Float], name: String) { - // Create figure - let ax = plt.gca() - let array = np.array([image]) - let pixels = array.reshape([imageHeight, imageWidth]) - if !FileManager.default.fileExists(atPath: outputFolder) { - try! FileManager.default.createDirectory(atPath: outputFolder, - withIntermediateDirectories: false, - attributes: nil) - } - ax.imshow(pixels, cmap: "gray") - plt.savefig("\(outputFolder)\(name).png", dpi: 300) - plt.close() -} +let outputFolder = "./output/" /// An autoencoder. 
struct Autoencoder: Layer { - typealias Input = Tensor - typealias Output = Tensor - - var encoder1 = Dense(inputSize: imageHeight * imageWidth, outputSize: 128, + var encoder1 = Dense( + inputSize: imageHeight * imageWidth, outputSize: 128, activation: relu) + var encoder2 = Dense(inputSize: 128, outputSize: 64, activation: relu) var encoder3 = Dense(inputSize: 64, outputSize: 12, activation: relu) var encoder4 = Dense(inputSize: 12, outputSize: 3, activation: relu) @@ -60,11 +37,13 @@ struct Autoencoder: Layer { var decoder1 = Dense(inputSize: 3, outputSize: 12, activation: relu) var decoder2 = Dense(inputSize: 12, outputSize: 64, activation: relu) var decoder3 = Dense(inputSize: 64, outputSize: 128, activation: relu) - var decoder4 = Dense(inputSize: 128, outputSize: imageHeight * imageWidth, + + var decoder4 = Dense( + inputSize: 128, outputSize: imageHeight * imageWidth, activation: tanh) @differentiable - func callAsFunction(_ input: Input) -> Output { + func callAsFunction(_ input: Tensor) -> Tensor { let encoder = input.sequenced(through: encoder1, encoder2, encoder3, encoder4) return encoder.sequenced(through: decoder1, decoder2, decoder3, decoder4) } @@ -76,11 +55,20 @@ let optimizer = RMSProp(for: autoencoder) // Training loop for epoch in 1...epochCount { - let sampleImage = Tensor(shape: [1, imageHeight * imageWidth], scalars: dataset.trainingImages[epoch].scalars) + let sampleImage = Tensor( + shape: [1, imageHeight * imageWidth], scalars: dataset.trainingImages[epoch].scalars) let testImage = autoencoder(sampleImage) - plot(image: sampleImage.scalars, name: "epoch-\(epoch)-input") - plot(image: testImage.scalars, name: "epoch-\(epoch)-output") + do { + try saveImage( + sampleImage, size: (imageWidth, imageHeight), directory: outputFolder, + name: "epoch-\(epoch)-input") + try saveImage( + testImage, size: (imageWidth, imageHeight), directory: outputFolder, + name: "epoch-\(epoch)-output") + } catch { + print("Could not save image with error: 
\(error)") + } let sampleLoss = meanSquaredError(predicted: testImage, expected: sampleImage) print("[Epoch: \(epoch)] Loss: \(sampleLoss)") diff --git a/GAN/README.md b/GAN/README.md index 9b6d1e24c18..0482aa75999 100644 --- a/GAN/README.md +++ b/GAN/README.md @@ -16,8 +16,6 @@ To begin, you'll need the [latest version of Swift for TensorFlow](https://github.com/tensorflow/swift/blob/master/Installation.md) installed. Make sure you've added the correct version of `swift` to your path. -This example requires Matplotlib and NumPy to be installed, for use in image output. - To train the model, run: ```sh diff --git a/GAN/main.swift b/GAN/main.swift index 05296f6b65d..f57a550274f 100644 --- a/GAN/main.swift +++ b/GAN/main.swift @@ -12,59 +12,42 @@ // See the License for the specific language governing permissions and // limitations under the License. +import Datasets import Foundation +import ModelSupport import TensorFlow -import Python -import Datasets - -// Import Python modules. -let matplotlib = Python.import("matplotlib") -let np = Python.import("numpy") - -// Use the AGG renderer for saving images to disk. -matplotlib.use("Agg") - -let plt = Python.import("matplotlib.pyplot") let epochCount = 10 let batchSize = 32 let outputFolder = "./output/" -let imageHeight = 28, imageWidth = 28 +let imageHeight = 28 +let imageWidth = 28 let imageSize = imageHeight * imageWidth let latentSize = 64 -func plotImage(_ image: Tensor, name: String) { - // Create figure. - let ax = plt.gca() - let array = np.array([image.scalars]) - let pixels = array.reshape(image.shape) - if !FileManager.default.fileExists(atPath: outputFolder) { - try! 
FileManager.default.createDirectory( - atPath: outputFolder, - withIntermediateDirectories: false, - attributes: nil) - } - ax.imshow(pixels, cmap: "gray") - plt.savefig("\(outputFolder)\(name).png", dpi: 300) - plt.close() -} - // Models struct Generator: Layer { - var dense1 = Dense(inputSize: latentSize, outputSize: latentSize * 2, - activation: { leakyRelu($0) }) - var dense2 = Dense(inputSize: latentSize * 2, outputSize: latentSize * 4, - activation: { leakyRelu($0) }) - var dense3 = Dense(inputSize: latentSize * 4, outputSize: latentSize * 8, - activation: { leakyRelu($0) }) - var dense4 = Dense(inputSize: latentSize * 8, outputSize: imageSize, - activation: tanh) - + var dense1 = Dense( + inputSize: latentSize, outputSize: latentSize * 2, + activation: { leakyRelu($0) }) + + var dense2 = Dense( + inputSize: latentSize * 2, outputSize: latentSize * 4, + activation: { leakyRelu($0) }) + + var dense3 = Dense( + inputSize: latentSize * 4, outputSize: latentSize * 8, + activation: { leakyRelu($0) }) + + var dense4 = Dense( + inputSize: latentSize * 8, outputSize: imageSize, + activation: tanh) + var batchnorm1 = BatchNorm(featureCount: latentSize * 2) var batchnorm2 = BatchNorm(featureCount: latentSize * 4) var batchnorm3 = BatchNorm(featureCount: latentSize * 8) - + @differentiable func callAsFunction(_ input: Tensor) -> Tensor { let x1 = batchnorm1(dense1(input)) @@ -75,15 +58,22 @@ struct Generator: Layer { } struct Discriminator: Layer { - var dense1 = Dense(inputSize: imageSize, outputSize: 256, - activation: { leakyRelu($0) }) - var dense2 = Dense(inputSize: 256, outputSize: 64, - activation: { leakyRelu($0) }) - var dense3 = Dense(inputSize: 64, outputSize: 16, - activation: { leakyRelu($0) }) - var dense4 = Dense(inputSize: 16, outputSize: 1, - activation: identity) - + var dense1 = Dense( + inputSize: imageSize, outputSize: 256, + activation: { leakyRelu($0) }) + + var dense2 = Dense( + inputSize: 256, outputSize: 64, + activation: { leakyRelu($0) }) + + 
var dense3 = Dense( + inputSize: 64, outputSize: 16, + activation: { leakyRelu($0) }) + + var dense4 = Dense( + inputSize: 16, outputSize: 1, + activation: identity) + @differentiable func callAsFunction(_ input: Tensor) -> Tensor { input.sequenced(through: dense1, dense2, dense3, dense4) @@ -94,16 +84,19 @@ struct Discriminator: Layer { @differentiable func generatorLoss(fakeLogits: Tensor) -> Tensor { - sigmoidCrossEntropy(logits: fakeLogits, - labels: Tensor(ones: fakeLogits.shape)) + sigmoidCrossEntropy( + logits: fakeLogits, + labels: Tensor(ones: fakeLogits.shape)) } @differentiable func discriminatorLoss(realLogits: Tensor, fakeLogits: Tensor) -> Tensor { - let realLoss = sigmoidCrossEntropy(logits: realLogits, - labels: Tensor(ones: realLogits.shape)) - let fakeLoss = sigmoidCrossEntropy(logits: fakeLogits, - labels: Tensor(zeros: fakeLogits.shape)) + let realLoss = sigmoidCrossEntropy( + logits: realLogits, + labels: Tensor(ones: realLogits.shape)) + let fakeLoss = sigmoidCrossEntropy( + logits: fakeLogits, + labels: Tensor(zeros: fakeLogits.shape)) return realLoss + fakeLoss } @@ -123,18 +116,28 @@ let optD = Adam(for: discriminator, learningRate: 2e-4, beta1: 0.5) // Noise vectors and plot function for testing let testImageGridSize = 4 let testVector = sampleVector(size: testImageGridSize * testImageGridSize) -func plotTestImage(_ testImage: Tensor, name: String) { - var gridImage = testImage.reshaped(to: [testImageGridSize, testImageGridSize, - imageHeight, imageWidth]) + +func saveImageGrid(_ testImage: Tensor, name: String) throws { + var gridImage = testImage.reshaped( + to: [ + testImageGridSize, testImageGridSize, + imageHeight, imageWidth, + ]) // Add padding. gridImage = gridImage.padded(forSizes: [(0, 0), (0, 0), (1, 1), (1, 1)], with: 1) // Transpose to create single image. 
gridImage = gridImage.transposed(withPermutations: [0, 2, 1, 3]) - gridImage = gridImage.reshaped(to: [(imageHeight + 2) * testImageGridSize, - (imageWidth + 2) * testImageGridSize]) + gridImage = gridImage.reshaped( + to: [ + (imageHeight + 2) * testImageGridSize, + (imageWidth + 2) * testImageGridSize, + ]) // Convert [-1, 1] range to [0, 1] range. gridImage = (gridImage + 1) / 2 - plotImage(gridImage, name: name) + + try saveImage( + gridImage, size: (gridImage.shape[0], gridImage.shape[1]), directory: outputFolder, + name: name) } print("Start training...") @@ -147,7 +150,7 @@ for epoch in 1...epochCount { // Perform alternative update. // Update generator. let vec1 = sampleVector(size: batchSize) - + let 𝛁generator = generator.gradient { generator -> Tensor in let fakeImages = generator(vec1) let fakeLogits = discriminator(fakeImages) @@ -155,12 +158,12 @@ for epoch in 1...epochCount { return loss } optG.update(&generator, along: 𝛁generator) - + // Update discriminator. let realImages = dataset.trainingImages.minibatch(at: i, batchSize: batchSize) let vec2 = sampleVector(size: batchSize) let fakeImages = generator(vec2) - + let 𝛁discriminator = discriminator.gradient { discriminator -> Tensor in let realLogits = discriminator(realImages) let fakeLogits = discriminator(fakeImages) @@ -169,12 +172,17 @@ for epoch in 1...epochCount { } optD.update(&discriminator, along: 𝛁discriminator) } - + // Start inference phase. 
Context.local.learningPhase = .inference let testImage = generator(testVector) - plotTestImage(testImage, name: "epoch-\(epoch)-output") - + + do { + try saveImageGrid(testImage, name: "epoch-\(epoch)-output") + } catch { + print("Could not save image grid with error: \(error)") + } + let lossG = generatorLoss(fakeLogits: testImage) print("[Epoch: \(epoch)] Loss-G: \(lossG)") } diff --git a/Package.swift b/Package.swift index 4d8ed439bcb..6d8dab911ea 100644 --- a/Package.swift +++ b/Package.swift @@ -11,6 +11,7 @@ let package = Package( products: [ .library(name: "ImageClassificationModels", targets: ["ImageClassificationModels"]), .library(name: "Datasets", targets: ["Datasets"]), + .library(name: "ModelSupport", targets: ["ModelSupport"]), .executable(name: "Custom-CIFAR10", targets: ["Custom-CIFAR10"]), .executable(name: "ResNet-CIFAR10", targets: ["ResNet-CIFAR10"]), .executable(name: "LeNet-MNIST", targets: ["LeNet-MNIST"]), @@ -21,7 +22,8 @@ let package = Package( targets: [ .target(name: "ImageClassificationModels", path: "Models/ImageClassification"), .target(name: "Datasets", path: "Datasets"), - .target(name: "Autoencoder", dependencies: ["Datasets"], path: "Autoencoder"), + .target(name: "ModelSupport", path: "Support"), + .target(name: "Autoencoder", dependencies: ["Datasets", "ModelSupport"], path: "Autoencoder"), .target(name: "Catch", path: "Catch"), .target(name: "Gym-FrozenLake", path: "Gym/FrozenLake"), .target(name: "Gym-CartPole", path: "Gym/CartPole"), @@ -41,6 +43,6 @@ let package = Package( sources: ["main.swift"]), .testTarget(name: "MiniGoTests", dependencies: ["MiniGo"]), .target(name: "Transformer", path: "Transformer"), - .target(name: "GAN", dependencies: ["Datasets"], path: "GAN"), + .target(name: "GAN", dependencies: ["Datasets", "ModelSupport"], path: "GAN"), ] ) diff --git a/Support/FileManagement.swift b/Support/FileManagement.swift new file mode 100644 index 00000000000..4a04147208d --- /dev/null +++ b/Support/FileManagement.swift 
@@ -0,0 +1,30 @@
+// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+import Foundation
+
+/// Creates a directory at `path` unless something already exists there.
+///
+/// Missing intermediate directories are created as well, so nested output
+/// locations such as `./output/run1/` work without extra setup.
+///
+/// - Parameter path: The location of the directory to create.
+/// - Throws: Any error reported by `FileManager` while creating the directory.
+public func createDirectoryIfMissing(at path: String) throws {
+    guard !FileManager.default.fileExists(atPath: path) else { return }
+    try FileManager.default.createDirectory(
+        atPath: path,
+        withIntermediateDirectories: true,
+        attributes: nil)
+}
diff --git a/Support/Image.swift b/Support/Image.swift
new file mode 100644
index 00000000000..8cef015a6c0
--- /dev/null
+++ b/Support/Image.swift
@@ -0,0 +1,131 @@
+// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+import Foundation
+import TensorFlow
+
+/// An image stored as either an 8-bit or a floating-point tensor.
+public struct Image {
+    /// Channel orderings that `init(jpeg:byteOrdering:)` can produce.
+    public enum ByteOrdering {
+        case bgr
+        case rgb
+    }
+
+    /// The pixel storage, tagged by scalar type.
+    enum ImageTensor {
+        case float(data: Tensor<Float>)
+        case uint8(data: Tensor<UInt8>)
+    }
+
+    let imageData: ImageTensor
+
+    /// Wraps a tensor of 8-bit pixel values.
+    public init(tensor: Tensor<UInt8>) {
+        self.imageData = .uint8(data: tensor)
+    }
+
+    /// Wraps a tensor of floating-point pixel values.
+    public init(tensor: Tensor<Float>) {
+        self.imageData = .float(data: tensor)
+    }
+
+    /// Decodes the JPEG file at `url` into a three-channel 8-bit image.
+    ///
+    /// - Parameters:
+    ///   - url: The location of the JPEG file to read.
+    ///   - byteOrdering: The channel order of the stored tensor. Defaults to `.rgb`.
+    public init(jpeg url: URL, byteOrdering: ByteOrdering = .rgb) {
+        let loadedFile = Raw.readFile(filename: StringTensor(url.absoluteString))
+        let loadedJpeg = Raw.decodeJpeg(contents: loadedFile, channels: 3, dctMethod: "")
+        switch byteOrdering {
+        case .bgr:
+            // The decoded tensor is rank 3 ([height, width, channels]), and the reversal
+            // mask needs one entry per axis; only the channel axis is flipped.
+            self.imageData = .uint8(
+                data: Raw.reverse(loadedJpeg, dims: Tensor<Bool>([false, false, true])))
+        case .rgb:
+            self.imageData = .uint8(data: loadedJpeg)
+        }
+    }
+
+    /// Encodes the image as a grayscale JPEG and writes it to `url`.
+    ///
+    /// Floating-point data is rescaled to the 0...255 range, using the minimum and maximum
+    /// taken along the first two axes, before conversion to `UInt8`.
+    ///
+    /// - Parameters:
+    ///   - url: The destination of the encoded file.
+    ///   - quality: The JPEG quality passed to the encoder. Defaults to 95.
+    public func save(to url: URL, quality: Int64 = 95) {
+        // This currently only saves in grayscale.
+        let outputImageData: Tensor<UInt8>
+        switch imageData {
+        case let .uint8(data):
+            outputImageData = data
+        case let .float(data):
+            let lowerBound = data.min(alongAxes: [0, 1])
+            let upperBound = data.max(alongAxes: [0, 1])
+            // A constant image has a zero-width value range; clamp the divisor so the
+            // rescale produces 0 rather than NaN.
+            let valueRange = max(upperBound - lowerBound, Float.ulpOfOne)
+            let adjustedData = (data - lowerBound) * (255.0 / valueRange)
+            outputImageData = Tensor<UInt8>(adjustedData)
+        }
+
+        let encodedJpeg = Raw.encodeJpeg(
+            image: outputImageData, format: .grayscale, quality: quality, xmpMetadata: "")
+        Raw.writeFile(filename: StringTensor(url.absoluteString), contents: encodedJpeg)
+    }
+
+    /// Returns a bilinearly resized copy of the image.
+    ///
+    /// - Parameter size: The two target dimensions, forwarded to the resize op in the
+    ///   order given. NOTE(review): TensorFlow documents this as (height, width) — confirm
+    ///   against callers.
+    /// - Returns: A floating-point image with the temporary batch axis removed.
+    public func resized(to size: (Int, Int)) -> Image {
+        let targetSize = Tensor<Int32>([Int32(size.0), Int32(size.1)])
+        let resizedBatch: Tensor<Float>
+        switch imageData {
+        case let .uint8(data):
+            resizedBatch = Raw.resizeBilinear(images: Tensor<UInt8>([data]), size: targetSize)
+        case let .float(data):
+            resizedBatch = Raw.resizeBilinear(images: Tensor<Float>([data]), size: targetSize)
+        }
+        // The resize op works on batches; drop the batch axis added above so the result
+        // keeps the rank of the input and can be passed to `save(to:quality:)`.
+        return Image(tensor: resizedBatch.squeezingShape(at: 0))
+    }
+}
+
+/// Saves a floating-point tensor as the grayscale JPEG `<name>.jpg` inside `directory`.
+///
+/// - Parameters:
+///   - tensor: The pixel values; reshaped to `[size.0, size.1, 1]` before saving.
+///   - size: The first two dimensions of the reshaped image.
+///   - directory: The output directory, created if it does not exist.
+///   - name: The file name without its extension.
+/// - Throws: Any error raised while creating `directory`.
+public func saveImage(
+    _ tensor: Tensor<Float>, size: (Int, Int), directory: String, name: String
+) throws {
+    try createDirectoryIfMissing(at: directory)
+    let reshapedTensor = tensor.reshaped(to: [size.0, size.1, 1])
+    let image = Image(tensor: reshapedTensor)
+    // Build the path from components so `directory` works with or without a trailing "/".
+    let outputURL = URL(fileURLWithPath: directory, isDirectory: true)
+        .appendingPathComponent("\(name).jpg")
+    image.save(to: outputURL)
+}