diff --git a/Autoencoder/Autoencoder1D/main.swift b/Autoencoder/Autoencoder1D/main.swift
index 45cc0a2d664..2851d01aabb 100644
--- a/Autoencoder/Autoencoder1D/main.swift
+++ b/Autoencoder/Autoencoder1D/main.swift
@@ -24,7 +24,7 @@ let imageHeight = 28
 let imageWidth = 28
 let outputFolder = "./output/"
 
-let dataset = FashionMNIST(batchSize: batchSize, flattening: true)
+let dataset = OldFashionMNIST(batchSize: batchSize, flattening: true)
 // An autoencoder.
 var autoencoder = Sequential {
   // The encoder.
diff --git a/Autoencoder/Autoencoder2D/main.swift b/Autoencoder/Autoencoder2D/main.swift
index d8a83b14dcc..428c997a274 100644
--- a/Autoencoder/Autoencoder2D/main.swift
+++ b/Autoencoder/Autoencoder2D/main.swift
@@ -26,7 +26,7 @@ let imageHeight = 28
 let imageWidth = 28
 let outputFolder = "./output/"
 
-let dataset = KuzushijiMNIST(batchSize: batchSize, flattening: true)
+let dataset = OldKuzushijiMNIST(batchSize: batchSize, flattening: true)
 
 // An autoencoder.
 struct Autoencoder2D: Layer {
diff --git a/Autoencoder/VAE1D/main.swift b/Autoencoder/VAE1D/main.swift
index bc0756b7f70..c8e51165b0e 100644
--- a/Autoencoder/VAE1D/main.swift
+++ b/Autoencoder/VAE1D/main.swift
@@ -26,7 +26,7 @@ let imageHeight = 28
 let imageWidth = 28
 let outputFolder = "./output/"
 
-let dataset = MNIST(batchSize: 128, flattening: true)
+let dataset = OldMNIST(batchSize: 128, flattening: true)
 
 let inputDim = 784  // 28*28 for any MNIST
 let hiddenDim = 400
@@ -84,7 +84,7 @@ func vaeLossFunction(
 }
 
 // TODO: Find a cleaner way of extracting individual images that doesn't require a second dataset.
-let singleImageDataset = MNIST(batchSize: 1, flattening: true)
+let singleImageDataset = OldMNIST(batchSize: 1, flattening: true)
 let individualTestImages = singleImageDataset.test
 var testImageIterator = individualTestImages.sequenced()
 
diff --git a/Benchmarks/Models/LeNetMnist.swift b/Benchmarks/Models/LeNetMnist.swift
index 9fa6bb35b82..f2cc33be90c 100755
--- a/Benchmarks/Models/LeNetMnist.swift
+++ b/Benchmarks/Models/LeNetMnist.swift
@@ -37,11 +37,11 @@ enum LeNetMNIST: BenchmarkModel {
   }
 
   static func makeInferenceBenchmark(settings: BenchmarkSettings) -> Benchmark {
-    return ImageClassificationInference<LeNet, MNIST>(settings: settings)
+    return ImageClassificationInference<LeNet, OldMNIST>(settings: settings)
   }
 
   static func makeTrainingBenchmark(settings: BenchmarkSettings) -> Benchmark {
-    return ImageClassificationTraining<LeNet, MNIST>(settings: settings)
+    return ImageClassificationTraining<LeNet, OldMNIST>(settings: settings)
   }
 }
diff --git a/DCGAN/main.swift b/DCGAN/main.swift
index b8e0f99756f..d8813a6e8c6 100644
--- a/DCGAN/main.swift
+++ b/DCGAN/main.swift
@@ -18,7 +18,7 @@ import ModelSupport
 import TensorFlow
 
 let batchSize = 512
-let mnist = MNIST(batchSize: batchSize, flattening: false, normalizing: true)
+let mnist = OldMNIST(batchSize: batchSize, flattening: false, normalizing: true)
 
 let outputFolder = "./output/"
 
diff --git a/Datasets/CIFAR10/CIFAR10.swift b/Datasets/CIFAR10/CIFAR10.swift
index b806d866e21..7b9e640e594 100644
--- a/Datasets/CIFAR10/CIFAR10.swift
+++ b/Datasets/CIFAR10/CIFAR10.swift
@@ -48,12 +48,13 @@ public struct CIFAR10<Entropy: RandomNumberGenerator> {
     self.init(
       batchSize: batchSize,
       entropy: entropy,
+      device: Device.default,
       remoteBinaryArchiveLocation: URL(
         string: "https://storage.googleapis.com/s4tf-hosted-binaries/datasets/CIFAR10/cifar-10-binary.tar.gz")!,
       normalizing: true)
   }
 
-  /// Creates an instance with `batchSize` using `remoteBinaryArchiveLocation`.
+  /// Creates an instance with `batchSize` on `device` using `remoteBinaryArchiveLocation`.
   ///
   /// - Parameters:
   ///   - entropy: a source of randomness used to shuffle sample ordering. It
@@ -65,6 +66,7 @@ public struct CIFAR10<Entropy: RandomNumberGenerator> {
   public init(
     batchSize: Int,
     entropy: Entropy,
+    device: Device,
     remoteBinaryArchiveLocation: URL,
     localStorageDirectory: URL = DatasetUtilities.defaultDirectory
       .appendingPathComponent("CIFAR10", isDirectory: true),
@@ -76,13 +78,13 @@
     let trainingSamples = loadCIFARTrainingFiles(in: localStorageDirectory)
     training = TrainingEpochs(samples: trainingSamples, batchSize: batchSize, entropy: entropy)
       .lazy.map { (batches: Batches) -> LazyMapSequence<Batches, LabeledImage> in
-        return batches.lazy.map{ makeBatch(samples: $0, normalizing: normalizing) }
+        return batches.lazy.map{ makeBatch(samples: $0, normalizing: normalizing, device: device) }
       }
 
     // Validation data
     let validationSamples = loadCIFARTestFile(in: localStorageDirectory)
     validation = validationSamples.inBatches(of: batchSize).lazy.map {
-      makeBatch(samples: $0, normalizing: normalizing)
+      makeBatch(samples: $0, normalizing: normalizing, device: device)
     }
   }
 }
@@ -145,19 +147,20 @@
 func loadCIFARTestFile(in localStorageDirectory: URL) -> [(data: [UInt8], label: Int32)] {
   return loadCIFARFile(named: "test_batch.bin", in: localStorageDirectory)
 }
 
-func makeBatch<BatchSamples: Collection>(samples: BatchSamples, normalizing: Bool) -> LabeledImage
-where BatchSamples.Element == (data: [UInt8], label: Int32) {
+fileprivate func makeBatch<BatchSamples: Collection>(
+  samples: BatchSamples, normalizing: Bool, device: Device
+) -> LabeledImage where BatchSamples.Element == (data: [UInt8], label: Int32) {
   let bytes = samples.lazy.map(\.data).reduce(into: [], +=)
-  let images = Tensor<UInt8>(shape: [samples.count, 3, 32, 32], scalars: bytes)
+  let images = Tensor<UInt8>(shape: [samples.count, 3, 32, 32], scalars: bytes, on: device)
   var imageTensor = Tensor<Float>(images.transposed(permutation: [0, 2, 3, 1]))
   imageTensor /= 255.0
   if normalizing {
-    let mean = Tensor<Float>([0.4913996898, 0.4821584196, 0.4465309242])
-    let std = Tensor<Float>([0.2470322324, 0.2434851280, 0.2615878417])
+    let mean = Tensor<Float>([0.4913996898, 0.4821584196, 0.4465309242], on: device)
+    let std = Tensor<Float>([0.2470322324, 0.2434851280, 0.2615878417], on: device)
     imageTensor = (imageTensor - mean) / std
   }
   let labels = Tensor<Int32>(samples.map(\.label))
   return LabeledImage(data: imageTensor, label: labels)
-}
\ No newline at end of file
+}
diff --git a/Datasets/CMakeLists.txt b/Datasets/CMakeLists.txt
index 661b09744b8..03be6f9dc53 100644
--- a/Datasets/CMakeLists.txt
+++ b/Datasets/CMakeLists.txt
@@ -15,6 +15,10 @@ add_library(Datasets
   MNIST/MNIST.swift
   MNIST/FashionMNIST.swift
   MNIST/KuzushijiMNIST.swift
+  MNIST/OldMNISTDatasetHandler.swift
+  MNIST/OldMNIST.swift
+  MNIST/OldFashionMNIST.swift
+  MNIST/OldKuzushijiMNIST.swift
   ObjectDetectionDataset.swift
   BostonHousing/BostonHousing.swift
   TextUnsupervised/TextUnsupervised.swift
diff --git a/Datasets/MNIST/FashionMNIST.swift b/Datasets/MNIST/FashionMNIST.swift
index b7809903508..8a63dfa912f 100644
--- a/Datasets/MNIST/FashionMNIST.swift
+++ b/Datasets/MNIST/FashionMNIST.swift
@@ -21,41 +21,79 @@
 import Foundation
 import TensorFlow
 import Batcher
 
-public struct FashionMNIST: ImageClassificationDataset {
-  public typealias SourceDataSet = [TensorPair<Float, Int32>]
-  public let training: Batcher<SourceDataSet>
-  public let test: Batcher<SourceDataSet>
+public struct FashionMNIST<Entropy: RandomNumberGenerator> {
+  /// Type of the collection of non-collated batches.
+  public typealias Batches = Slices<Sampling<[(data: [UInt8], label: Int32)], ArraySlice<Int>>>
+  /// The type of the training data, represented as a sequence of epochs, which
+  /// are collections of batches.
+  public typealias Training = LazyMapSequence<
+    TrainingEpochs<[(data: [UInt8], label: Int32)], Entropy>,
+    LazyMapSequence<Batches, LabeledImage>
+  >
+  /// The type of the validation data, represented as a collection of batches.
+  public typealias Validation = LazyMapSequence<Slices<[(data: [UInt8], label: Int32)]>, LabeledImage>
+  /// The training epochs.
+  public let training: Training
+  /// The validation batches.
+  public let validation: Validation
 
-  public init(batchSize: Int) {
-    self.init(batchSize: batchSize, flattening: false, normalizing: false)
-  }
-
-  public init(
-    batchSize: Int, flattening: Bool = false, normalizing: Bool = false,
-    localStorageDirectory: URL = DatasetUtilities.defaultDirectory
-      .appendingPathComponent("FashionMNIST", isDirectory: true)
-  ) {
-    training = Batcher(
-      on: fetchMNISTDataset(
-        localStorageDirectory: localStorageDirectory,
-        remoteBaseDirectory: "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/",
-        imagesFilename: "train-images-idx3-ubyte",
-        labelsFilename: "train-labels-idx1-ubyte",
-        flattening: flattening,
-        normalizing: normalizing),
-      batchSize: batchSize,
-      numWorkers: 1, //No need to use parallelism since everything is loaded in memory
-      shuffle: true)
+  /// Creates an instance with `batchSize`.
+  ///
+  /// - Parameter entropy: a source of randomness used to shuffle sample
+  ///   ordering. It will be stored in `self`, so if it is only pseudorandom
+  ///   and has value semantics, the sequence of epochs is deterministic and not
+  ///   dependent on other operations.
+  public init(batchSize: Int, entropy: Entropy) {
+    self.init(batchSize: batchSize, device: Device.default, entropy: entropy,
+              flattening: false, normalizing: false)
+  }
 
-    test = Batcher(
-      on: fetchMNISTDataset(
-        localStorageDirectory: localStorageDirectory,
-        remoteBaseDirectory: "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/",
-        imagesFilename: "t10k-images-idx3-ubyte",
-        labelsFilename: "t10k-labels-idx1-ubyte",
-        flattening: flattening,
-        normalizing: normalizing),
-      batchSize: batchSize,
-      numWorkers: 1) //No need to use parallelism since everything is loaded in memory
+  /// Creates an instance with `batchSize` on `device`.
+  ///
+  /// - Parameters:
+  ///   - entropy: a source of randomness used to shuffle sample ordering. It
+  ///     will be stored in `self`, so if it is only pseudorandom and has value
+  ///     semantics, the sequence of epochs is deterministic and not dependent
+  ///     on other operations.
+  ///   - flattening: flattens the data to be a 2d-tensor iff `true`. The default value
+  ///     is `false`.
+  ///   - normalizing: normalizes the batches to have values from -1.0 to 1.0 iff `true`.
+  ///     The default value is `false`.
+  ///   - localStorageDirectory: the directory in which the dataset is stored.
+  public init(
+    batchSize: Int, device: Device, entropy: Entropy, flattening: Bool = false,
+    normalizing: Bool = false,
+    localStorageDirectory: URL = DatasetUtilities.defaultDirectory
+      .appendingPathComponent("FashionMNIST", isDirectory: true)
+  ) {
+    training = TrainingEpochs(
+      samples: fetchMNISTDataset(
+        localStorageDirectory: localStorageDirectory,
+        remoteBaseDirectory: "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/",
+        imagesFilename: "train-images-idx3-ubyte",
+        labelsFilename: "train-labels-idx1-ubyte"),
+      batchSize: batchSize, entropy: entropy
+    ).lazy.map { (batches: Batches) -> LazyMapSequence<Batches, LabeledImage> in
+      return batches.lazy.map{ makeMNISTBatch(
+        samples: $0, flattening: flattening, normalizing: normalizing, device: device
+      )}
+    }
+
+    validation = fetchMNISTDataset(
+      localStorageDirectory: localStorageDirectory,
+      remoteBaseDirectory: "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/",
+      imagesFilename: "t10k-images-idx3-ubyte",
+      labelsFilename: "t10k-labels-idx1-ubyte"
+    ).inBatches(of: batchSize).lazy.map {
+      makeMNISTBatch(samples: $0, flattening: flattening, normalizing: normalizing,
+        device: device)
+    }
   }
 }
+
+extension FashionMNIST: ImageClassificationData where Entropy == SystemRandomNumberGenerator {
+  /// Creates an instance with `batchSize`.
+  public init(batchSize: Int) {
+    self.init(batchSize: batchSize, entropy: SystemRandomNumberGenerator())
+  }
+}
\ No newline at end of file
diff --git a/Datasets/MNIST/KuzushijiMNIST.swift b/Datasets/MNIST/KuzushijiMNIST.swift
index d0c8f82abba..551ddc6a995 100644
--- a/Datasets/MNIST/KuzushijiMNIST.swift
+++ b/Datasets/MNIST/KuzushijiMNIST.swift
@@ -20,41 +20,79 @@
 import Foundation
 import TensorFlow
 import Batcher
 
-public struct KuzushijiMNIST: ImageClassificationDataset {
-  public typealias SourceDataSet = [TensorPair<Float, Int32>]
-  public let training: Batcher<SourceDataSet>
-  public let test: Batcher<SourceDataSet>
+public struct KuzushijiMNIST<Entropy: RandomNumberGenerator> {
+  /// Type of the collection of non-collated batches.
+  public typealias Batches = Slices<Sampling<[(data: [UInt8], label: Int32)], ArraySlice<Int>>>
+  /// The type of the training data, represented as a sequence of epochs, which
+  /// are collections of batches.
+  public typealias Training = LazyMapSequence<
+    TrainingEpochs<[(data: [UInt8], label: Int32)], Entropy>,
+    LazyMapSequence<Batches, LabeledImage>
+  >
+  /// The type of the validation data, represented as a collection of batches.
+  public typealias Validation = LazyMapSequence<Slices<[(data: [UInt8], label: Int32)]>, LabeledImage>
+  /// The training epochs.
+  public let training: Training
+  /// The validation batches.
+  public let validation: Validation
 
-  public init(batchSize: Int) {
-    self.init(batchSize: batchSize, flattening: false, normalizing: false)
-  }
-
-  public init(
-    batchSize: Int, flattening: Bool = false, normalizing: Bool = false,
-    localStorageDirectory: URL = DatasetUtilities.defaultDirectory
-      .appendingPathComponent("KuzushijiMNIST", isDirectory: true)
-  ) {
-    training = Batcher(
-      on: fetchMNISTDataset(
-        localStorageDirectory: localStorageDirectory,
-        remoteBaseDirectory: "https://storage.googleapis.com/s4tf-hosted-binaries/datasets/KMNIST",
-        imagesFilename: "train-images-idx3-ubyte",
-        labelsFilename: "train-labels-idx1-ubyte",
-        flattening: flattening,
-        normalizing: normalizing),
-      batchSize: batchSize,
-      numWorkers: 1, //No need to use parallelism since everything is loaded in memory
-      shuffle: true)
+  /// Creates an instance with `batchSize`.
+  ///
+  /// - Parameter entropy: a source of randomness used to shuffle sample
+  ///   ordering. It will be stored in `self`, so if it is only pseudorandom
+  ///   and has value semantics, the sequence of epochs is deterministic and not
+  ///   dependent on other operations.
+  public init(batchSize: Int, entropy: Entropy) {
+    self.init(batchSize: batchSize, device: Device.default, entropy: entropy,
+              flattening: false, normalizing: false)
+  }
 
-    test = Batcher(
-      on: fetchMNISTDataset(
-        localStorageDirectory: localStorageDirectory,
-        remoteBaseDirectory: "https://storage.googleapis.com/s4tf-hosted-binaries/datasets/KMNIST",
-        imagesFilename: "t10k-images-idx3-ubyte",
-        labelsFilename: "t10k-labels-idx1-ubyte",
-        flattening: flattening,
-        normalizing: normalizing),
-      batchSize: batchSize,
-      numWorkers: 1) //No need to use parallelism since everything is loaded in memory
+  /// Creates an instance with `batchSize` on `device`.
+  ///
+  /// - Parameters:
+  ///   - entropy: a source of randomness used to shuffle sample ordering. It
+  ///     will be stored in `self`, so if it is only pseudorandom and has value
+  ///     semantics, the sequence of epochs is deterministic and not dependent
+  ///     on other operations.
+  ///   - flattening: flattens the data to be a 2d-tensor iff `true`. The default value
+  ///     is `false`.
+  ///   - normalizing: normalizes the batches to have values from -1.0 to 1.0 iff `true`.
+  ///     The default value is `false`.
+  ///   - localStorageDirectory: the directory in which the dataset is stored.
+  public init(
+    batchSize: Int, device: Device, entropy: Entropy, flattening: Bool = false,
+    normalizing: Bool = false,
+    localStorageDirectory: URL = DatasetUtilities.defaultDirectory
+      .appendingPathComponent("KuzushijiMNIST", isDirectory: true)
+  ) {
+    training = TrainingEpochs(
+      samples: fetchMNISTDataset(
+        localStorageDirectory: localStorageDirectory,
+        remoteBaseDirectory: "https://storage.googleapis.com/s4tf-hosted-binaries/datasets/KMNIST",
+        imagesFilename: "train-images-idx3-ubyte",
+        labelsFilename: "train-labels-idx1-ubyte"),
+      batchSize: batchSize, entropy: entropy
+    ).lazy.map { (batches: Batches) -> LazyMapSequence<Batches, LabeledImage> in
+      return batches.lazy.map{ makeMNISTBatch(
+        samples: $0, flattening: flattening, normalizing: normalizing, device: device
+      )}
+    }
+
+    validation = fetchMNISTDataset(
+      localStorageDirectory: localStorageDirectory,
+      remoteBaseDirectory: "https://storage.googleapis.com/s4tf-hosted-binaries/datasets/KMNIST",
+      imagesFilename: "t10k-images-idx3-ubyte",
+      labelsFilename: "t10k-labels-idx1-ubyte"
+    ).inBatches(of: batchSize).lazy.map {
+      makeMNISTBatch(samples: $0, flattening: flattening, normalizing: normalizing,
+        device: device)
+    }
   }
 }
+
+extension KuzushijiMNIST: ImageClassificationData where Entropy == SystemRandomNumberGenerator {
+  /// Creates an instance with `batchSize`.
+  public init(batchSize: Int) {
+    self.init(batchSize: batchSize, entropy: SystemRandomNumberGenerator())
+  }
+}
\ No newline at end of file
diff --git a/Datasets/MNIST/MNIST.swift b/Datasets/MNIST/MNIST.swift
index 4f4414ab028..3faa5be9f07 100644
--- a/Datasets/MNIST/MNIST.swift
+++ b/Datasets/MNIST/MNIST.swift
@@ -21,41 +21,79 @@
 import Foundation
 import TensorFlow
 import Batcher
 
-public struct MNIST: ImageClassificationDataset {
-  public typealias SourceDataSet = [TensorPair<Float, Int32>]
-  public let training: Batcher<SourceDataSet>
-  public let test: Batcher<SourceDataSet>
+public struct MNIST<Entropy: RandomNumberGenerator> {
+  /// Type of the collection of non-collated batches.
+  public typealias Batches = Slices<Sampling<[(data: [UInt8], label: Int32)], ArraySlice<Int>>>
+  /// The type of the training data, represented as a sequence of epochs, which
+  /// are collections of batches.
+  public typealias Training = LazyMapSequence<
+    TrainingEpochs<[(data: [UInt8], label: Int32)], Entropy>,
+    LazyMapSequence<Batches, LabeledImage>
+  >
+  /// The type of the validation data, represented as a collection of batches.
+  public typealias Validation = LazyMapSequence<Slices<[(data: [UInt8], label: Int32)]>, LabeledImage>
+  /// The training epochs.
+  public let training: Training
+  /// The validation batches.
+  public let validation: Validation
 
-  public init(batchSize: Int) {
-    self.init(batchSize: batchSize, flattening: false, normalizing: false)
-  }
+  /// Creates an instance with `batchSize`.
+  ///
+  /// - Parameter entropy: a source of randomness used to shuffle sample
+  ///   ordering. It will be stored in `self`, so if it is only pseudorandom
+  ///   and has value semantics, the sequence of epochs is deterministic and not
+  ///   dependent on other operations.
+  public init(batchSize: Int, entropy: Entropy) {
+    self.init(batchSize: batchSize, device: Device.default, entropy: entropy,
+              flattening: false, normalizing: false)
+  }
 
-  public init(
-    batchSize: Int, flattening: Bool = false, normalizing: Bool = false,
-    localStorageDirectory: URL = DatasetUtilities.defaultDirectory
-      .appendingPathComponent("MNIST", isDirectory: true)
-  ) {
-    training = Batcher(
-      on: fetchMNISTDataset(
-        localStorageDirectory: localStorageDirectory,
-        remoteBaseDirectory: "https://storage.googleapis.com/cvdf-datasets/mnist",
-        imagesFilename: "train-images-idx3-ubyte",
-        labelsFilename: "train-labels-idx1-ubyte",
-        flattening: flattening,
-        normalizing: normalizing),
-      batchSize: batchSize,
-      numWorkers: 1, //No need to use parallelism since everything is loaded in memory
-      shuffle: true)
-
-    test = Batcher(
-      on: fetchMNISTDataset(
-        localStorageDirectory: localStorageDirectory,
-        remoteBaseDirectory: "https://storage.googleapis.com/cvdf-datasets/mnist",
-        imagesFilename: "t10k-images-idx3-ubyte",
-        labelsFilename: "t10k-labels-idx1-ubyte",
-        flattening: flattening,
-        normalizing: normalizing),
-      batchSize: batchSize,
-      numWorkers: 1) //No need to use parallelism since everything is loaded in memory
+  /// Creates an instance with `batchSize` on `device`.
+  ///
+  /// - Parameters:
+  ///   - entropy: a source of randomness used to shuffle sample ordering. It
+  ///     will be stored in `self`, so if it is only pseudorandom and has value
+  ///     semantics, the sequence of epochs is deterministic and not dependent
+  ///     on other operations.
+  ///   - flattening: flattens the data to be a 2d-tensor iff `true`. The default value
+  ///     is `false`.
+  ///   - normalizing: normalizes the batches to have values from -1.0 to 1.0 iff `true`.
+  ///     The default value is `false`.
+  ///   - localStorageDirectory: the directory in which the dataset is stored.
+  public init(
+    batchSize: Int, device: Device, entropy: Entropy, flattening: Bool = false,
+    normalizing: Bool = false,
+    localStorageDirectory: URL = DatasetUtilities.defaultDirectory
+      .appendingPathComponent("MNIST", isDirectory: true)
+  ) {
+    training = TrainingEpochs(
+      samples: fetchMNISTDataset(
+        localStorageDirectory: localStorageDirectory,
+        remoteBaseDirectory: "https://storage.googleapis.com/cvdf-datasets/mnist",
+        imagesFilename: "train-images-idx3-ubyte",
+        labelsFilename: "train-labels-idx1-ubyte"),
+      batchSize: batchSize, entropy: entropy
+    ).lazy.map { (batches: Batches) -> LazyMapSequence<Batches, LabeledImage> in
+      return batches.lazy.map{ makeMNISTBatch(
+        samples: $0, flattening: flattening, normalizing: normalizing, device: device
+      )}
+    }
+
+    validation = fetchMNISTDataset(
+      localStorageDirectory: localStorageDirectory,
+      remoteBaseDirectory: "https://storage.googleapis.com/cvdf-datasets/mnist",
+      imagesFilename: "t10k-images-idx3-ubyte",
+      labelsFilename: "t10k-labels-idx1-ubyte"
+    ).inBatches(of: batchSize).lazy.map {
+      makeMNISTBatch(samples: $0, flattening: flattening, normalizing: normalizing,
+        device: device)
+    }
   }
 }
+
+extension MNIST: ImageClassificationData where Entropy == SystemRandomNumberGenerator {
+  /// Creates an instance with `batchSize`.
+  public init(batchSize: Int) {
+    self.init(batchSize: batchSize, entropy: SystemRandomNumberGenerator())
+  }
+}
diff --git a/Datasets/MNIST/MNISTDatasetHandler.swift b/Datasets/MNIST/MNISTDatasetHandler.swift
index 02a03c8e431..9fcedcfd261 100644
--- a/Datasets/MNIST/MNISTDatasetHandler.swift
+++ b/Datasets/MNIST/MNISTDatasetHandler.swift
@@ -16,53 +16,53 @@
 import Foundation
 import TensorFlow
 
 func fetchMNISTDataset(
-    localStorageDirectory: URL,
-    remoteBaseDirectory: String,
-    imagesFilename: String,
-    labelsFilename: String,
-    flattening: Bool,
-    normalizing: Bool
-) -> [TensorPair<Float, Int32>] {
-    guard let remoteRoot = URL(string: remoteBaseDirectory) else {
-        fatalError("Failed to create MNIST root url: \(remoteBaseDirectory)")
-    }
+  localStorageDirectory: URL,
+  remoteBaseDirectory: String,
+  imagesFilename: String,
+  labelsFilename: String
+) -> [(data: [UInt8], label: Int32)] {
+  guard let remoteRoot = URL(string: remoteBaseDirectory) else {
+    fatalError("Failed to create MNIST root url: \(remoteBaseDirectory)")
+  }
 
-    let imagesData = DatasetUtilities.fetchResource(
-        filename: imagesFilename,
-        fileExtension: "gz",
-        remoteRoot: remoteRoot,
-        localStorageDirectory: localStorageDirectory)
-    let labelsData = DatasetUtilities.fetchResource(
-        filename: labelsFilename,
-        fileExtension: "gz",
-        remoteRoot: remoteRoot,
-        localStorageDirectory: localStorageDirectory)
+  let imagesData = DatasetUtilities.fetchResource(
+    filename: imagesFilename,
+    fileExtension: "gz",
+    remoteRoot: remoteRoot,
+    localStorageDirectory: localStorageDirectory)
+  let labelsData = DatasetUtilities.fetchResource(
+    filename: labelsFilename,
+    fileExtension: "gz",
+    remoteRoot: remoteRoot,
+    localStorageDirectory: localStorageDirectory)
 
-    let images = [UInt8](imagesData).dropFirst(16).map(Float.init)
-    let labels = [UInt8](labelsData).dropFirst(8).map(Int32.init)
+  let images = [UInt8](imagesData).dropFirst(16)
+  let labels = [UInt8](labelsData).dropFirst(8).map(Int32.init)
+
+  var labeledImages: [(data: [UInt8], label: Int32)] = []
 
-    let rowCount = labels.count
-    let (imageWidth, imageHeight) = (28, 28)
+  let imageByteSize = 28 * 28
+  for imageIndex in 0..<labels.count {
+    let baseAddress = imageIndex * imageByteSize
+    labeledImages.append((
+      data: [UInt8](images[(baseAddress + 16)..<(baseAddress + 16 + imageByteSize)]),
+      label: labels[imageIndex]))
+  }
 
-    if flattening {
-        var flattenedImages =
-            Tensor(shape: [rowCount, imageHeight * imageWidth], scalars: images)
-            / 255.0
-        if normalizing {
-            flattenedImages = flattenedImages * 2.0 - 1.0
-        }
-        return (0..<rowCount).map {
-            TensorPair(first: flattenedImages[$0], second: Tensor<Int32>(labels[$0]))
-        }
-    } else {
-        var images =
-            Tensor(shape: [rowCount, 1, imageHeight, imageWidth], scalars: images)
-            .transposed(permutation: [0, 2, 3, 1]) / 255.0
-        if normalizing {
-            images = images * 2.0 - 1.0
-        }
-        return (0..<rowCount).map {
-            TensorPair(first: images[$0], second: Tensor<Int32>(labels[$0]))
-        }
-    }
+  return labeledImages
+}
+
+func makeMNISTBatch<BatchSamples: Collection>(
+  samples: BatchSamples, flattening: Bool, normalizing: Bool, device: Device
+) -> LabeledImage where BatchSamples.Element == (data: [UInt8], label: Int32) {
+  let bytes = samples.lazy.map(\.data).reduce(into: [], +=)
+  let shape: TensorShape = flattening ? [samples.count, 28 * 28] : [samples.count, 28, 28, 1]
+  let images = Tensor<UInt8>(shape: shape, scalars: bytes, on: device)
+
+  var imageTensor = Tensor<Float>(images) / 255.0
+  if normalizing {
+    imageTensor = imageTensor * 2.0 - 1.0
+  }
+
+  let labels = Tensor<Int32>(samples.map(\.label))
+  return LabeledImage(data: imageTensor, label: labels)
 }
diff --git a/Datasets/MNIST/OldFashionMNIST.swift b/Datasets/MNIST/OldFashionMNIST.swift
new file mode 100644
index 00000000000..2fbc3b16ad1
--- /dev/null
+++ b/Datasets/MNIST/OldFashionMNIST.swift
@@ -0,0 +1,59 @@
+// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Original source:
+// "Fashion-MNIST: a Novel Image Dataset for Benchmarking Machine Learning Algorithms"
+// Han Xiao and Kashif Rasul and Roland Vollgraf
+// https://arxiv.org/abs/1708.07747
+
+import Foundation
+import TensorFlow
+import Batcher
+
+public struct OldFashionMNIST: ImageClassificationDataset {
+  public typealias SourceDataSet = [TensorPair<Float, Int32>]
+  public let training: Batcher<SourceDataSet>
+  public let test: Batcher<SourceDataSet>
+
+  public init(batchSize: Int) {
+    self.init(batchSize: batchSize, flattening: false, normalizing: false)
+  }
+
+  public init(
+    batchSize: Int, flattening: Bool = false, normalizing: Bool = false,
+    localStorageDirectory: URL = DatasetUtilities.defaultDirectory
+      .appendingPathComponent("FashionMNIST", isDirectory: true)
+  ) {
+    training = Batcher(
+      on: oldFetchMNISTDataset(
+        localStorageDirectory: localStorageDirectory,
+        remoteBaseDirectory: "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/",
+        imagesFilename: "train-images-idx3-ubyte",
+        labelsFilename: "train-labels-idx1-ubyte",
+        flattening: flattening,
+        normalizing: normalizing),
+      batchSize: batchSize,
+      numWorkers: 1, //No need to use parallelism since everything is loaded in memory
+      shuffle: true)
+
+    test = Batcher(
+      on: oldFetchMNISTDataset(
+        localStorageDirectory: localStorageDirectory,
+        remoteBaseDirectory: "http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/",
+        imagesFilename: "t10k-images-idx3-ubyte",
+        labelsFilename: "t10k-labels-idx1-ubyte",
+        flattening: flattening,
+        normalizing: normalizing),
+      batchSize: batchSize,
+      numWorkers: 1) //No need to use parallelism since everything is loaded in memory
+  }
+}
\ No newline at end of file
diff --git a/Datasets/MNIST/OldKuzushijiMNIST.swift b/Datasets/MNIST/OldKuzushijiMNIST.swift
new file mode 100644
index 00000000000..5d767f1d98d
--- /dev/null
+++ b/Datasets/MNIST/OldKuzushijiMNIST.swift
@@ -0,0 +1,58 @@
+// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Original source:
+// "KMNIST Dataset" (created by CODH), https://arxiv.org/abs/1812.01718
+// adapted from "Kuzushiji Dataset" (created by NIJL and others), doi:10.20676/00000341
+
+import Foundation
+import TensorFlow
+import Batcher
+
+public struct OldKuzushijiMNIST: ImageClassificationDataset {
+  public typealias SourceDataSet = [TensorPair<Float, Int32>]
+  public let training: Batcher<SourceDataSet>
+  public let test: Batcher<SourceDataSet>
+
+  public init(batchSize: Int) {
+    self.init(batchSize: batchSize, flattening: false, normalizing: false)
+  }
+
+  public init(
+    batchSize: Int, flattening: Bool = false, normalizing: Bool = false,
+    localStorageDirectory: URL = DatasetUtilities.defaultDirectory
+      .appendingPathComponent("KuzushijiMNIST", isDirectory: true)
+  ) {
+    training = Batcher(
+      on: oldFetchMNISTDataset(
+        localStorageDirectory: localStorageDirectory,
+        remoteBaseDirectory: "https://storage.googleapis.com/s4tf-hosted-binaries/datasets/KMNIST",
+        imagesFilename: "train-images-idx3-ubyte",
+        labelsFilename: "train-labels-idx1-ubyte",
+        flattening: flattening,
+        normalizing: normalizing),
+      batchSize: batchSize,
+      numWorkers: 1, //No need to use parallelism since everything is loaded in memory
+      shuffle: true)
+
+    test = Batcher(
+      on: oldFetchMNISTDataset(
+        localStorageDirectory: localStorageDirectory,
+        remoteBaseDirectory: "https://storage.googleapis.com/s4tf-hosted-binaries/datasets/KMNIST",
+        imagesFilename: "t10k-images-idx3-ubyte",
+        labelsFilename: "t10k-labels-idx1-ubyte",
+        flattening: flattening,
+        normalizing: normalizing),
+      batchSize: batchSize,
+      numWorkers: 1) //No need to use parallelism since everything is loaded in memory
+  }
+}
\ No newline at end of file
diff --git a/Datasets/MNIST/OldMNIST.swift b/Datasets/MNIST/OldMNIST.swift
new file mode 100644
index 00000000000..b42391fa438
--- /dev/null
+++ b/Datasets/MNIST/OldMNIST.swift
@@ -0,0 +1,59 @@
+// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Original source:
+// "The MNIST database of handwritten digits"
+// Yann LeCun, Corinna Cortes, and Christopher J.C. Burges
+// http://yann.lecun.com/exdb/mnist/
+
+import Foundation
+import TensorFlow
+import Batcher
+
+public struct OldMNIST: ImageClassificationDataset {
+  public typealias SourceDataSet = [TensorPair<Float, Int32>]
+  public let training: Batcher<SourceDataSet>
+  public let test: Batcher<SourceDataSet>
+
+  public init(batchSize: Int) {
+    self.init(batchSize: batchSize, flattening: false, normalizing: false)
+  }
+
+  public init(
+    batchSize: Int, flattening: Bool = false, normalizing: Bool = false,
+    localStorageDirectory: URL = DatasetUtilities.defaultDirectory
+      .appendingPathComponent("MNIST", isDirectory: true)
+  ) {
+    training = Batcher(
+      on: oldFetchMNISTDataset(
+        localStorageDirectory: localStorageDirectory,
+        remoteBaseDirectory: "https://storage.googleapis.com/cvdf-datasets/mnist",
+        imagesFilename: "train-images-idx3-ubyte",
+        labelsFilename: "train-labels-idx1-ubyte",
+        flattening: flattening,
+        normalizing: normalizing),
+      batchSize: batchSize,
+      numWorkers: 1, //No need to use parallelism since everything is loaded in memory
+      shuffle: true)
+
+    test = Batcher(
+      on: oldFetchMNISTDataset(
+        localStorageDirectory: localStorageDirectory,
+        remoteBaseDirectory: "https://storage.googleapis.com/cvdf-datasets/mnist",
+        imagesFilename: "t10k-images-idx3-ubyte",
+        labelsFilename: "t10k-labels-idx1-ubyte",
+        flattening: flattening,
+        normalizing: normalizing),
+      batchSize: batchSize,
+      numWorkers: 1) //No need to use parallelism since everything is loaded in memory
+  }
+}
\ No newline at end of file
diff --git a/Datasets/MNIST/OldMNISTDatasetHandler.swift b/Datasets/MNIST/OldMNISTDatasetHandler.swift
new file mode 100644
index 00000000000..6f8e87fcbf9
--- /dev/null
+++ b/Datasets/MNIST/OldMNISTDatasetHandler.swift
@@ -0,0 +1,67 @@
+// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+import Foundation
+import TensorFlow
+
+func oldFetchMNISTDataset(
+    localStorageDirectory: URL,
+    remoteBaseDirectory: String,
+    imagesFilename: String,
+    labelsFilename: String,
+    flattening: Bool,
+    normalizing: Bool
+) -> [TensorPair<Float, Int32>] {
+    guard let remoteRoot = URL(string: remoteBaseDirectory) else {
+        fatalError("Failed to create MNIST root url: \(remoteBaseDirectory)")
+    }
+
+    let imagesData = DatasetUtilities.fetchResource(
+        filename: imagesFilename,
+        fileExtension: "gz",
+        remoteRoot: remoteRoot,
+        localStorageDirectory: localStorageDirectory)
+    let labelsData = DatasetUtilities.fetchResource(
+        filename: labelsFilename,
+        fileExtension: "gz",
+        remoteRoot: remoteRoot,
+        localStorageDirectory: localStorageDirectory)
+
+    let images = [UInt8](imagesData).dropFirst(16).map(Float.init)
+    let labels = [UInt8](labelsData).dropFirst(8).map(Int32.init)
+
+    let rowCount = labels.count
+    let (imageWidth, imageHeight) = (28, 28)
+
+    if flattening {
+        var flattenedImages =
+            Tensor(shape: [rowCount, imageHeight * imageWidth], scalars: images)
+            / 255.0
+        if normalizing {
+            flattenedImages = flattenedImages * 2.0 - 1.0
+        }
+        return (0..<rowCount).map {
+            TensorPair(first: flattenedImages[$0], second: Tensor<Int32>(labels[$0]))
+        }
+    } else {
+        var images =
+            Tensor(shape: [rowCount, 1, imageHeight, imageWidth], scalars: images)
+            .transposed(permutation: [0, 2, 3, 1]) / 255.0
+        if normalizing {
+            images = images * 2.0 - 1.0
+        }
+        return (0..<rowCount).map {
+            TensorPair(first: images[$0], second: Tensor<Int32>(labels[$0]))
+        }
+    }
+}
\ No newline at end of file
diff --git a/Examples/LeNet-MNIST/main.swift b/Examples/LeNet-MNIST/main.swift
index 7404bd095c5..e5658701466 100644
--- a/Examples/LeNet-MNIST/main.swift
+++ b/Examples/LeNet-MNIST/main.swift
@@ -43,20 +43,20 @@
 }
 
 // The training loop.
-for epoch in 1...epochCount {
+for (epoch, epochBatches) in dataset.training.prefix(epochCount).enumerated() {
   var trainStats = Statistics()
   var testStats = Statistics()
 
   Context.local.learningPhase = .training
-  for batch in dataset.training.sequenced() {
-    let (images, labels) = (batch.first, batch.second)
+  for batch in epochBatches {
+    let (images, labels) = (batch.data, batch.label)
     // Compute the gradient with respect to the model.
     let 𝛁model = TensorFlow.gradient(at: classifier) { classifier -> Tensor<Float> in
       let ŷ = classifier(images)
       let correctPredictions = ŷ.argmax(squeezingAxis: 1) .== labels
       trainStats.correctGuessCount += Int(
         Tensor<Int32>(correctPredictions).sum().scalarized())
-      trainStats.totalGuessCount += batch.first.shape[0]
+      trainStats.totalGuessCount += images.shape[0]
       let loss = softmaxCrossEntropy(logits: ŷ, labels: labels)
       trainStats.totalLoss += loss.scalarized()
       trainStats.batches += 1
@@ -67,13 +67,13 @@
   }
 
   Context.local.learningPhase = .inference
-  for batch in dataset.test.sequenced() {
-    let (images, labels) = (batch.first, batch.second)
+  for batch in dataset.validation {
+    let (images, labels) = (batch.data, batch.label)
     // Compute loss on test set
     let ŷ = classifier(images)
     let correctPredictions = ŷ.argmax(squeezingAxis: 1) .== labels
     testStats.correctGuessCount += Int(Tensor<Int32>(correctPredictions).sum().scalarized())
-    testStats.totalGuessCount += batch.first.shape[0]
+    testStats.totalGuessCount += images.shape[0]
     let loss = softmaxCrossEntropy(logits: ŷ, labels: labels)
     testStats.totalLoss += loss.scalarized()
     testStats.batches += 1
@@ -83,7 +83,7 @@ let trainAccuracy = Float(trainStats.correctGuessCount) / Float(trainStats.totalGuessCount)
   let testAccuracy = Float(testStats.correctGuessCount) / Float(testStats.totalGuessCount)
   print(
     """
-    [Epoch \(epoch)] \
+    [Epoch \(epoch + 1)] \
     Training Loss: \(trainStats.totalLoss / Float(trainStats.batches)), \
     Training Accuracy: \(trainStats.correctGuessCount)/\(trainStats.totalGuessCount) \
     (\(trainAccuracy)), \
@@ -91,4 +91,4 @@
     Test Accuracy: \(testStats.correctGuessCount)/\(testStats.totalGuessCount) \
     (\(testAccuracy))
     """)
-}
+}
\ No newline at end of file
diff --git a/GAN/main.swift b/GAN/main.swift
index 9702cf3c4bc..70280867782 100644
--- a/GAN/main.swift
+++ b/GAN/main.swift
@@ -105,7 +105,7 @@ func sampleVector(size: Int) -> Tensor<Float> {
   Tensor(randomNormal: [size, latentSize])
 }
 
-let dataset = MNIST(batchSize: batchSize, flattening: true, normalizing: true)
+let dataset = OldMNIST(batchSize: batchSize, flattening: true, normalizing: true)
 
 var generator = Generator()
 var discriminator = Discriminator()
diff --git a/Tests/DatasetsTests/CIFAR10/CIFAR10Tests.swift b/Tests/DatasetsTests/CIFAR10/CIFAR10Tests.swift
index ec1c15f01bb..94d2c423c28 100644
--- a/Tests/DatasetsTests/CIFAR10/CIFAR10Tests.swift
+++ b/Tests/DatasetsTests/CIFAR10/CIFAR10Tests.swift
@@ -8,6 +8,7 @@ final class CIFAR10Tests: XCTestCase {
 
     let dataset = CIFAR10(
       batchSize: 1,
      entropy: SystemRandomNumberGenerator(),
+      device: Device.default,
       remoteBinaryArchiveLocation: URL(
         string:
@@ -33,6 +34,7 @@
 
     let dataset = CIFAR10(
      batchSize: 50000,
      entropy: SystemRandomNumberGenerator(),
+      device: Device.default,
       remoteBinaryArchiveLocation: URL(
         string:
diff --git a/Tests/DatasetsTests/MNIST/MNISTTests.swift b/Tests/DatasetsTests/MNIST/MNISTTests.swift
index 47944cde64a..6daa0d98b0f 100644
--- a/Tests/DatasetsTests/MNIST/MNISTTests.swift
+++ b/Tests/DatasetsTests/MNIST/MNISTTests.swift
@@ -7,22 +7,26 @@ final class MNISTTests: XCTestCase {
     let dataset = MNIST(batchSize: 1)
     var totalCount = 0
 
-    for example in dataset.training.sequenced() {
-      XCTAssertTrue((0..<10).contains(example.second[0].scalar!))
-      XCTAssertEqual(example.first.shape, [1, 28, 28, 1])
-      totalCount += 1
+    for epochBatches in dataset.training.prefix(1) {
+      for batch in epochBatches {
+        XCTAssertTrue((0..<10).contains(batch.label[0].scalar!))
+        XCTAssertEqual(batch.data.shape, [1, 28, 28, 1])
+        totalCount += 1
+      }
     }
     XCTAssertEqual(totalCount, 60000)
   }
 
   func testCreateFashionMNIST() {
     let dataset = FashionMNIST(batchSize: 1)
     var totalCount = 0
 
-    for example in dataset.training.sequenced() {
-      XCTAssertTrue((0..<10).contains(example.second[0].scalar!))
-      XCTAssertEqual(example.first.shape, [1, 28, 28, 1])
-      totalCount += 1
+    for epochBatches in dataset.training.prefix(1) {
+      for batch in epochBatches {
+        XCTAssertTrue((0..<10).contains(batch.label[0].scalar!))
+        XCTAssertEqual(batch.data.shape, [1, 28, 28, 1])
+        totalCount += 1
+      }
     }
     XCTAssertEqual(totalCount, 60000)
   }
@@ -31,10 +35,12 @@
     let dataset = KuzushijiMNIST(batchSize: 1)
     var totalCount = 0
 
-    for example in dataset.training.sequenced() {
-      XCTAssertTrue((0..<10).contains(example.second[0].scalar!))
-      XCTAssertEqual(example.first.shape, [1, 28, 28, 1])
-      totalCount += 1
+    for epochBatches in dataset.training.prefix(1) {
+      for batch in epochBatches {
+        XCTAssertTrue((0..<10).contains(batch.label[0].scalar!))
+        XCTAssertEqual(batch.data.shape, [1, 28, 28, 1])
+        totalCount += 1
+      }
     }
     XCTAssertEqual(totalCount, 60000)
  }
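
For readers migrating their own code past this change: the Epochs-based datasets above replace `Batcher`'s `training.sequenced()` with a lazy sequence of epochs, each a collection of collated `LabeledImage` batches. Below is a minimal sketch of the consuming side, mirroring the `Examples/LeNet-MNIST` changes in this patch; the model, optimizer, and hyperparameter choices are illustrative and not part of the patch.

```swift
import Datasets
import ImageClassificationModels
import TensorFlow

let epochCount = 2
let batchSize = 128

// The convenience init added by this patch seeds shuffling with
// SystemRandomNumberGenerator and collates batches on Device.default.
let dataset = MNIST(batchSize: batchSize)
var classifier = LeNet()
let optimizer = SGD(for: classifier, learningRate: 0.1)

for (epoch, epochBatches) in dataset.training.prefix(epochCount).enumerated() {
  Context.local.learningPhase = .training
  for batch in epochBatches {
    // Each batch is a LabeledImage; `data` holds images, `label` holds classes.
    let (images, labels) = (batch.data, batch.label)
    let 𝛁model = gradient(at: classifier) { classifier -> Tensor<Float> in
      softmaxCrossEntropy(logits: classifier(images), labels: labels)
    }
    optimizer.update(&classifier, along: 𝛁model)
  }

  Context.local.learningPhase = .inference
  var correct = 0
  var total = 0
  for batch in dataset.validation {
    let logits = classifier(batch.data)
    let predictions = logits.argmax(squeezingAxis: 1) .== batch.label
    correct += Int(Tensor<Int32>(predictions).sum().scalarized())
    total += batch.data.shape[0]
  }
  print("[Epoch \(epoch + 1)] validation accuracy: \(Float(correct) / Float(total))")
}
```

Unlike `Batcher`'s `shuffle: true`, shuffling here comes from the `entropy` generator captured at initialization, so a seeded pseudorandom generator with value semantics makes the epoch sequence reproducible.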
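The new `device:` parameter threads through to `makeMNISTBatch`/`makeBatch`, so collated tensors are created directly on the target accelerator rather than being copied there batch by batch inside the training loop. A sketch under the assumption that an X10/XLA backend is available; `Device.defaultXLA` is the accessor used by the X10 runtime, and `Device.default` is the eager fallback.

```swift
import Datasets
import TensorFlow

// Assumption: running with the X10 backend; otherwise use Device.default.
let device = Device.defaultXLA

// Batches from both the training epochs and the validation set are
// allocated on `device` at creation time by makeMNISTBatch.
let dataset = FashionMNIST(
  batchSize: 256,
  device: device,
  entropy: SystemRandomNumberGenerator(),
  flattening: false,
  normalizing: true)

for epochBatches in dataset.training.prefix(1) {
  for batch in epochBatches {
    // Confirm placement; no host-to-device transfer happens here.
    assert(batch.data.device == device)
  }
}
```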