Skip to content
This repository was archived by the owner on Apr 23, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@
.swiftpm
cifar-10-batches-py/
cifar-10-batches-bin/
output/
2 changes: 0 additions & 2 deletions Autoencoder/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@ To begin, you'll need the [latest version of Swift for
TensorFlow](https://github.com/tensorflow/swift/blob/master/Installation.md)
installed. Make sure you've added the correct version of `swift` to your path.

This example requires Matplotlib and NumPy to be installed, for use in image output.

To train the model, run:

```
Expand Down
56 changes: 20 additions & 36 deletions Autoencoder/main.swift
Original file line number Diff line number Diff line change
Expand Up @@ -12,59 +12,38 @@
// See the License for the specific language governing permissions and
// limitations under the License.

import Datasets
import Foundation
import ModelSupport
import TensorFlow
import Python
import Datasets

// Import Python modules
let matplotlib = Python.import("matplotlib")
let np = Python.import("numpy")

// Use the AGG renderer for saving images to disk.
matplotlib.use("Agg")

let plt = Python.import("matplotlib.pyplot")

let epochCount = 10
let batchSize = 100
let outputFolder = "./output/"
let imageHeight = 28, imageWidth = 28
let imageHeight = 28
let imageWidth = 28

func plot(image: [Float], name: String) {
// Create figure
let ax = plt.gca()
let array = np.array([image])
let pixels = array.reshape([imageHeight, imageWidth])
if !FileManager.default.fileExists(atPath: outputFolder) {
try! FileManager.default.createDirectory(atPath: outputFolder,
withIntermediateDirectories: false,
attributes: nil)
}
ax.imshow(pixels, cmap: "gray")
plt.savefig("\(outputFolder)\(name).png", dpi: 300)
plt.close()
}
let outputFolder = "./output/"

/// An autoencoder.
struct Autoencoder: Layer {
typealias Input = Tensor<Float>
typealias Output = Tensor<Float>

var encoder1 = Dense<Float>(inputSize: imageHeight * imageWidth, outputSize: 128,
var encoder1 = Dense<Float>(
inputSize: imageHeight * imageWidth, outputSize: 128,
activation: relu)

var encoder2 = Dense<Float>(inputSize: 128, outputSize: 64, activation: relu)
var encoder3 = Dense<Float>(inputSize: 64, outputSize: 12, activation: relu)
var encoder4 = Dense<Float>(inputSize: 12, outputSize: 3, activation: relu)

var decoder1 = Dense<Float>(inputSize: 3, outputSize: 12, activation: relu)
var decoder2 = Dense<Float>(inputSize: 12, outputSize: 64, activation: relu)
var decoder3 = Dense<Float>(inputSize: 64, outputSize: 128, activation: relu)
var decoder4 = Dense<Float>(inputSize: 128, outputSize: imageHeight * imageWidth,

var decoder4 = Dense<Float>(
inputSize: 128, outputSize: imageHeight * imageWidth,
activation: tanh)

@differentiable
func callAsFunction(_ input: Input) -> Output {
func callAsFunction(_ input: Tensor<Float>) -> Tensor<Float> {
let encoder = input.sequenced(through: encoder1, encoder2, encoder3, encoder4)
return encoder.sequenced(through: decoder1, decoder2, decoder3, decoder4)
}
Expand All @@ -76,11 +55,16 @@ let optimizer = RMSProp(for: autoencoder)

// Training loop
for epoch in 1...epochCount {
let sampleImage = Tensor(shape: [1, imageHeight * imageWidth], scalars: dataset.trainingImages[epoch].scalars)
let sampleImage = Tensor(
shape: [1, imageHeight * imageWidth], scalars: dataset.trainingImages[epoch].scalars)
let testImage = autoencoder(sampleImage)

plot(image: sampleImage.scalars, name: "epoch-\(epoch)-input")
plot(image: testImage.scalars, name: "epoch-\(epoch)-output")
saveImage(
tensor: sampleImage, size: (imageWidth, imageHeight), directory: outputFolder,
name: "epoch-\(epoch)-input")
saveImage(
tensor: testImage, size: (imageWidth, imageHeight), directory: outputFolder,
name: "epoch-\(epoch)-output")

let sampleLoss = meanSquaredError(predicted: testImage, expected: sampleImage)
print("[Epoch: \(epoch)] Loss: \(sampleLoss)")
Expand Down
2 changes: 0 additions & 2 deletions GAN/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@ To begin, you'll need the [latest version of Swift for
TensorFlow](https://github.com/tensorflow/swift/blob/master/Installation.md)
installed. Make sure you've added the correct version of `swift` to your path.

This example requires Matplotlib and NumPy to be installed, for use in image output.

To train the model, run:

```sh
Expand Down
133 changes: 68 additions & 65 deletions GAN/main.swift
Original file line number Diff line number Diff line change
Expand Up @@ -12,59 +12,42 @@
// See the License for the specific language governing permissions and
// limitations under the License.

import Datasets
import Foundation
import ModelSupport
import TensorFlow
import Python
import Datasets

// Import Python modules.
let matplotlib = Python.import("matplotlib")
let np = Python.import("numpy")

// Use the AGG renderer for saving images to disk.
matplotlib.use("Agg")

let plt = Python.import("matplotlib.pyplot")

let epochCount = 10
let batchSize = 32
let outputFolder = "./output/"
let imageHeight = 28, imageWidth = 28
let imageHeight = 28
let imageWidth = 28
let imageSize = imageHeight * imageWidth
let latentSize = 64

func plotImage(_ image: Tensor<Float>, name: String) {
// Create figure.
let ax = plt.gca()
let array = np.array([image.scalars])
let pixels = array.reshape(image.shape)
if !FileManager.default.fileExists(atPath: outputFolder) {
try! FileManager.default.createDirectory(
atPath: outputFolder,
withIntermediateDirectories: false,
attributes: nil)
}
ax.imshow(pixels, cmap: "gray")
plt.savefig("\(outputFolder)\(name).png", dpi: 300)
plt.close()
}

// Models

struct Generator: Layer {
var dense1 = Dense<Float>(inputSize: latentSize, outputSize: latentSize * 2,
activation: { leakyRelu($0) })
var dense2 = Dense<Float>(inputSize: latentSize * 2, outputSize: latentSize * 4,
activation: { leakyRelu($0) })
var dense3 = Dense<Float>(inputSize: latentSize * 4, outputSize: latentSize * 8,
activation: { leakyRelu($0) })
var dense4 = Dense<Float>(inputSize: latentSize * 8, outputSize: imageSize,
activation: tanh)

var dense1 = Dense<Float>(
inputSize: latentSize, outputSize: latentSize * 2,
activation: { leakyRelu($0) })

var dense2 = Dense<Float>(
inputSize: latentSize * 2, outputSize: latentSize * 4,
activation: { leakyRelu($0) })

var dense3 = Dense<Float>(
inputSize: latentSize * 4, outputSize: latentSize * 8,
activation: { leakyRelu($0) })

var dense4 = Dense<Float>(
inputSize: latentSize * 8, outputSize: imageSize,
activation: tanh)

var batchnorm1 = BatchNorm<Float>(featureCount: latentSize * 2)
var batchnorm2 = BatchNorm<Float>(featureCount: latentSize * 4)
var batchnorm3 = BatchNorm<Float>(featureCount: latentSize * 8)

@differentiable
func callAsFunction(_ input: Tensor<Float>) -> Tensor<Float> {
let x1 = batchnorm1(dense1(input))
Expand All @@ -75,15 +58,22 @@ struct Generator: Layer {
}

struct Discriminator: Layer {
var dense1 = Dense<Float>(inputSize: imageSize, outputSize: 256,
activation: { leakyRelu($0) })
var dense2 = Dense<Float>(inputSize: 256, outputSize: 64,
activation: { leakyRelu($0) })
var dense3 = Dense<Float>(inputSize: 64, outputSize: 16,
activation: { leakyRelu($0) })
var dense4 = Dense<Float>(inputSize: 16, outputSize: 1,
activation: identity)

var dense1 = Dense<Float>(
inputSize: imageSize, outputSize: 256,
activation: { leakyRelu($0) })

var dense2 = Dense<Float>(
inputSize: 256, outputSize: 64,
activation: { leakyRelu($0) })

var dense3 = Dense<Float>(
inputSize: 64, outputSize: 16,
activation: { leakyRelu($0) })

var dense4 = Dense<Float>(
inputSize: 16, outputSize: 1,
activation: identity)

@differentiable
func callAsFunction(_ input: Tensor<Float>) -> Tensor<Float> {
input.sequenced(through: dense1, dense2, dense3, dense4)
Expand All @@ -94,16 +84,19 @@ struct Discriminator: Layer {

@differentiable
func generatorLoss(fakeLogits: Tensor<Float>) -> Tensor<Float> {
sigmoidCrossEntropy(logits: fakeLogits,
labels: Tensor(ones: fakeLogits.shape))
sigmoidCrossEntropy(
logits: fakeLogits,
labels: Tensor(ones: fakeLogits.shape))
}

@differentiable
func discriminatorLoss(realLogits: Tensor<Float>, fakeLogits: Tensor<Float>) -> Tensor<Float> {
let realLoss = sigmoidCrossEntropy(logits: realLogits,
labels: Tensor(ones: realLogits.shape))
let fakeLoss = sigmoidCrossEntropy(logits: fakeLogits,
labels: Tensor(zeros: fakeLogits.shape))
let realLoss = sigmoidCrossEntropy(
logits: realLogits,
labels: Tensor(ones: realLogits.shape))
let fakeLoss = sigmoidCrossEntropy(
logits: fakeLogits,
labels: Tensor(zeros: fakeLogits.shape))
return realLoss + fakeLoss
}

Expand All @@ -123,18 +116,28 @@ let optD = Adam(for: discriminator, learningRate: 2e-4, beta1: 0.5)
// Noise vectors and plot function for testing
let testImageGridSize = 4
let testVector = sampleVector(size: testImageGridSize * testImageGridSize)
func plotTestImage(_ testImage: Tensor<Float>, name: String) {
var gridImage = testImage.reshaped(to: [testImageGridSize, testImageGridSize,
imageHeight, imageWidth])

func saveImageGrid(_ testImage: Tensor<Float>, name: String) {
var gridImage = testImage.reshaped(
to: [
testImageGridSize, testImageGridSize,
imageHeight, imageWidth
])
// Add padding.
gridImage = gridImage.padded(forSizes: [(0, 0), (0, 0), (1, 1), (1, 1)], with: 1)
// Transpose to create single image.
gridImage = gridImage.transposed(withPermutations: [0, 2, 1, 3])
gridImage = gridImage.reshaped(to: [(imageHeight + 2) * testImageGridSize,
(imageWidth + 2) * testImageGridSize])
gridImage = gridImage.reshaped(
to: [
(imageHeight + 2) * testImageGridSize,
(imageWidth + 2) * testImageGridSize
])
// Convert [-1, 1] range to [0, 1] range.
gridImage = (gridImage + 1) / 2
plotImage(gridImage, name: name)

saveImage(
tensor: gridImage, size: (gridImage.shape[0], gridImage.shape[1]), directory: outputFolder,
name: name)
}

print("Start training...")
Expand All @@ -147,20 +150,20 @@ for epoch in 1...epochCount {
// Perform alternative update.
// Update generator.
let vec1 = sampleVector(size: batchSize)

let 𝛁generator = generator.gradient { generator -> Tensor<Float> in
let fakeImages = generator(vec1)
let fakeLogits = discriminator(fakeImages)
let loss = generatorLoss(fakeLogits: fakeLogits)
return loss
}
optG.update(&generator, along: 𝛁generator)

// Update discriminator.
let realImages = dataset.trainingImages.minibatch(at: i, batchSize: batchSize)
let vec2 = sampleVector(size: batchSize)
let fakeImages = generator(vec2)

let 𝛁discriminator = discriminator.gradient { discriminator -> Tensor<Float> in
let realLogits = discriminator(realImages)
let fakeLogits = discriminator(fakeImages)
Expand All @@ -169,12 +172,12 @@ for epoch in 1...epochCount {
}
optD.update(&discriminator, along: 𝛁discriminator)
}

// Start inference phase.
Context.local.learningPhase = .inference
let testImage = generator(testVector)
plotTestImage(testImage, name: "epoch-\(epoch)-output")
saveImageGrid(testImage, name: "epoch-\(epoch)-output")

let lossG = generatorLoss(fakeLogits: testImage)
print("[Epoch: \(epoch)] Loss-G: \(lossG)")
}
6 changes: 4 additions & 2 deletions Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ let package = Package(
products: [
.library(name: "ImageClassificationModels", targets: ["ImageClassificationModels"]),
.library(name: "Datasets", targets: ["Datasets"]),
.library(name: "ModelSupport", targets: ["ModelSupport"]),
.executable(name: "Custom-CIFAR10", targets: ["Custom-CIFAR10"]),
.executable(name: "ResNet-CIFAR10", targets: ["ResNet-CIFAR10"]),
.executable(name: "LeNet-MNIST", targets: ["LeNet-MNIST"]),
Expand All @@ -21,7 +22,8 @@ let package = Package(
targets: [
.target(name: "ImageClassificationModels", path: "Models/ImageClassification"),
.target(name: "Datasets", path: "Datasets"),
.target(name: "Autoencoder", dependencies: ["Datasets"], path: "Autoencoder"),
.target(name: "ModelSupport", path: "Support"),
.target(name: "Autoencoder", dependencies: ["Datasets", "ModelSupport"], path: "Autoencoder"),
.target(name: "Catch", path: "Catch"),
.target(name: "Gym-FrozenLake", path: "Gym/FrozenLake"),
.target(name: "Gym-CartPole", path: "Gym/CartPole"),
Expand All @@ -41,6 +43,6 @@ let package = Package(
sources: ["main.swift"]),
.testTarget(name: "MiniGoTests", dependencies: ["MiniGo"]),
.target(name: "Transformer", path: "Transformer"),
.target(name: "GAN", dependencies: ["Datasets"], path: "GAN"),
.target(name: "GAN", dependencies: ["Datasets", "ModelSupport"], path: "GAN"),
]
)
24 changes: 24 additions & 0 deletions Support/FileManagement.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import Foundation

public func createDirectoryIfMissing(path: String) {
if !FileManager.default.fileExists(atPath: path) {
try! FileManager.default.createDirectory(
atPath: path,
withIntermediateDirectories: false,
attributes: nil)
}
}
Loading