diff --git a/Benchmarks/Models/ResNetCIFAR10.swift b/Benchmarks/Models/ResNetCIFAR10.swift index 04004707494..e7430ea2940 100755 --- a/Benchmarks/Models/ResNetCIFAR10.swift +++ b/Benchmarks/Models/ResNetCIFAR10.swift @@ -38,11 +38,11 @@ enum ResNetCIFAR10: BenchmarkModel { } static func makeInferenceBenchmark(settings: BenchmarkSettings) -> Benchmark { - return ImageClassificationInference(settings: settings) + return ImageClassificationInference(settings: settings) } static func makeTrainingBenchmark(settings: BenchmarkSettings) -> Benchmark { - return ImageClassificationTraining(settings: settings) + return ImageClassificationTraining(settings: settings) } } diff --git a/Datasets/CIFAR10/CIFAR10.swift b/Datasets/CIFAR10/CIFAR10.swift index e408c1164a0..b806d866e21 100644 --- a/Datasets/CIFAR10/CIFAR10.swift +++ b/Datasets/CIFAR10/CIFAR10.swift @@ -22,120 +22,142 @@ import ModelSupport import TensorFlow import Batcher -public struct CIFAR10: ImageClassificationDataset { - public typealias SourceDataSet = [TensorPair] - public let training: Batcher - public let test: Batcher - - public init(batchSize: Int) { - self.init( - batchSize: batchSize, - remoteBinaryArchiveLocation: URL( - string: "https://storage.googleapis.com/s4tf-hosted-binaries/datasets/CIFAR10/cifar-10-binary.tar.gz")!, - normalizing: true) +public struct CIFAR10 { + /// Type of the collection of non-collated batches. + public typealias Batches = Slices>> + /// The type of the training data, represented as a sequence of epochs, which + /// are collection of batches. + public typealias Training = LazyMapSequence< + TrainingEpochs<[(data: [UInt8], label: Int32)], Entropy>, + LazyMapSequence + > + /// The type of the validation data, represented as a collection of batches. + public typealias Validation = LazyMapSequence, LabeledImage> + /// The training epochs. + public let training: Training + /// The validation batches. + public let validation: Validation + + /// Creates an instance with `batchSize`. + /// + /// - Parameter entropy: a source of randomness used to shuffle sample + /// ordering. It will be stored in `self`, so if it is only pseudorandom + /// and has value semantics, the sequence of epochs is deterministic and not + /// dependent on other operations. + public init(batchSize: Int, entropy: Entropy) { + self.init( + batchSize: batchSize, + entropy: entropy, + remoteBinaryArchiveLocation: URL( + string: "https://storage.googleapis.com/s4tf-hosted-binaries/datasets/CIFAR10/cifar-10-binary.tar.gz")!, + normalizing: true) + } + + /// Creates an instance with `batchSize` using `remoteBinaryArchiveLocation`. + /// + /// - Parameters: + /// - entropy: a source of randomness used to shuffle sample ordering. It + /// will be stored in `self`, so if it is only pseudorandom and has value + /// semantics, the sequence of epochs is deterministic and not dependent + /// on other operations. + /// - normalizing: normalizes the batches with the mean and standard deviation + /// of the dataset iff `true`. Default value is `true`. + public init( + batchSize: Int, + entropy: Entropy, + remoteBinaryArchiveLocation: URL, + localStorageDirectory: URL = DatasetUtilities.defaultDirectory + .appendingPathComponent("CIFAR10", isDirectory: true), + normalizing: Bool + ){ + downloadCIFAR10IfNotPresent(from: remoteBinaryArchiveLocation, to: localStorageDirectory) + + // Training data + let trainingSamples = loadCIFARTrainingFiles(in: localStorageDirectory) + training = TrainingEpochs(samples: trainingSamples, batchSize: batchSize, entropy: entropy) + .lazy.map { (batches: Batches) -> LazyMapSequence in + return batches.lazy.map{ makeBatch(samples: $0, normalizing: normalizing) } + } + + // Validation data + let validationSamples = loadCIFARTestFile(in: localStorageDirectory) + validation = validationSamples.inBatches(of: batchSize).lazy.map { + makeBatch(samples: $0, normalizing: normalizing) } + } +} - public init( - batchSize: Int, - remoteBinaryArchiveLocation: URL, - localStorageDirectory: URL = DatasetUtilities.defaultDirectory - .appendingPathComponent("CIFAR10", isDirectory: true), - normalizing: Bool) - { - downloadCIFAR10IfNotPresent(from: remoteBinaryArchiveLocation, to: localStorageDirectory) - self.training = Batcher( - on: loadCIFARTrainingFiles(localStorageDirectory: localStorageDirectory, normalizing: normalizing), - batchSize: batchSize, - numWorkers: 1, //No need to use parallelism since everything is loaded in memory - shuffle: true) - self.test = Batcher( - on: loadCIFARTestFile(localStorageDirectory: localStorageDirectory, normalizing: normalizing), - batchSize: batchSize, - numWorkers: 1) //No need to use parallelism since everything is loaded in memory - } +extension CIFAR10: ImageClassificationData where Entropy == SystemRandomNumberGenerator { + /// Creates an instance with `batchSize`. + public init(batchSize: Int) { + self.init(batchSize: batchSize, entropy: SystemRandomNumberGenerator()) + } } func downloadCIFAR10IfNotPresent(from location: URL, to directory: URL) { - let downloadPath = directory.appendingPathComponent("cifar-10-batches-bin").path - let directoryExists = FileManager.default.fileExists(atPath: downloadPath) - let contentsOfDir = try? FileManager.default.contentsOfDirectory(atPath: downloadPath) - let directoryEmpty = (contentsOfDir == nil) || (contentsOfDir!.isEmpty) + let downloadPath = directory.appendingPathComponent("cifar-10-batches-bin").path + let directoryExists = FileManager.default.fileExists(atPath: downloadPath) + let contentsOfDir = try? FileManager.default.contentsOfDirectory(atPath: downloadPath) + let directoryEmpty = (contentsOfDir == nil) || (contentsOfDir!.isEmpty) - guard !directoryExists || directoryEmpty else { return } + guard !directoryExists || directoryEmpty else { return } - let _ = DatasetUtilities.downloadResource( - filename: "cifar-10-binary", fileExtension: "tar.gz", - remoteRoot: location.deletingLastPathComponent(), localStorageDirectory: directory) + let _ = DatasetUtilities.downloadResource( + filename: "cifar-10-binary", fileExtension: "tar.gz", + remoteRoot: location.deletingLastPathComponent(), localStorageDirectory: directory) } -func loadCIFARFile(named name: String, in directory: URL, normalizing: Bool = true) -> [TensorPair] { - let path = directory.appendingPathComponent("cifar-10-batches-bin/\(name)").path - - let imageCount = 10000 - guard let fileContents = try? Data(contentsOf: URL(fileURLWithPath: path)) else { - printError("Could not read dataset file: \(name)") - exit(-1) - } - guard fileContents.count == 30_730_000 else { - printError( - "Dataset file \(name) should have 30730000 bytes, instead had \(fileContents.count)") - exit(-1) - } - - var bytes: [UInt8] = [] - var labels: [Int64] = [] - - let imageByteSize = 3073 - for imageIndex in 0..(shape: [imageCount], scalars: labels) - let images = Tensor(shape: [imageCount, 3, 32, 32], scalars: bytes) - - // Transpose from the CIFAR-provided N(CHW) to TF's default NHWC. - var imageTensor = Tensor(images.transposed(permutation: [0, 2, 3, 1])) - - // The value of mean and std were calculated with the following Swift code: - // ``` - // import TensorFlow - // import Datasets - // import Foundation - // let urlString = "https://storage.googleapis.com/s4tf-hosted-binaries/datasets/CIFAR10/cifar-10-binary.tar.gz" - // let cifar = CIFAR10(batchSize: 50000, - // remoteBinaryArchiveLocation: URL(string: urlString)!, - // normalizing: false) - // for batch in cifar.training.sequenced() { - // let images = Tensor(batch.first) / 255.0 - // let mom = images.moments(squeezingAxes: [0,1,2]) - // print("mean: \(mom.mean) std: \(sqrt(mom.variance))") - // } - // ``` - if normalizing { - let mean = Tensor( - [0.4913996898, - 0.4821584196, - 0.4465309242]) - let std = Tensor( - [0.2470322324, - 0.2434851280, - 0.2615878417]) - imageTensor = ((imageTensor / 255.0) - mean) / std - } - - return (0..(labelTensor[$0])) } - +func loadCIFARFile(named name: String, in directory: URL) -> [(data: [UInt8], label: Int32)] { + let path = directory.appendingPathComponent("cifar-10-batches-bin/\(name)").path + + let imageCount = 10000 + guard let fileContents = try? Data(contentsOf: URL(fileURLWithPath: path)) else { + printError("Could not read dataset file: \(name)") + exit(-1) + } + guard fileContents.count == 30_730_000 else { + printError( + "Dataset file \(name) should have 30730000 bytes, instead had \(fileContents.count)") + exit(-1) + } + + var labeledImages: [(data: [UInt8], label: Int32)] = [] + + let imageByteSize = 3073 + for imageIndex in 0.. [TensorPair] { - let data = (1..<6).map { - loadCIFARFile(named: "data_batch_\($0).bin", in: localStorageDirectory, normalizing: normalizing) - } - return data.reduce([], +) +func loadCIFARTrainingFiles(in localStorageDirectory: URL) -> [(data: [UInt8], label: Int32)] { + let data = (1..<6).map { + loadCIFARFile(named: "data_batch_\($0).bin", in: localStorageDirectory) + } + return data.reduce([], +) } -func loadCIFARTestFile(localStorageDirectory: URL, normalizing: Bool = true) -> [TensorPair] { - return loadCIFARFile(named: "test_batch.bin", in: localStorageDirectory, normalizing: normalizing) +func loadCIFARTestFile(in localStorageDirectory: URL) -> [(data: [UInt8], label: Int32)] { + return loadCIFARFile(named: "test_batch.bin", in: localStorageDirectory) } + +func makeBatch(samples: BatchSamples, normalizing: Bool) -> LabeledImage +where BatchSamples.Element == (data: [UInt8], label: Int32) { + let bytes = samples.lazy.map(\.data).reduce(into: [], +=) + let images = Tensor(shape: [samples.count, 3, 32, 32], scalars: bytes) + + var imageTensor = Tensor(images.transposed(permutation: [0, 2, 3, 1])) + imageTensor /= 255.0 + if normalizing { + let mean = Tensor([0.4913996898, 0.4821584196, 0.4465309242]) + let std = Tensor([0.2470322324, 0.2434851280, 0.2615878417]) + imageTensor = (imageTensor - mean) / std + } + + let labels = Tensor(samples.map(\.label)) + return LabeledImage(data: imageTensor, label: labels) +} \ No newline at end of file diff --git a/Datasets/CIFAR10/OldCIFAR10.swift b/Datasets/CIFAR10/OldCIFAR10.swift new file mode 100644 index 00000000000..887ca6581b6 --- /dev/null +++ b/Datasets/CIFAR10/OldCIFAR10.swift @@ -0,0 +1,139 @@ +// Copyright 2019 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// Original source: +// "The CIFAR-10 dataset" +// Alex Krizhevsky, Vinod Nair, and Geoffrey Hinton. +// https://www.cs.toronto.edu/~kriz/cifar.html +import Foundation +import ModelSupport +import TensorFlow +import Batcher + +public struct OldCIFAR10: ImageClassificationDataset { + public typealias SourceDataSet = [TensorPair] + public let training: Batcher + public let test: Batcher + + public init(batchSize: Int) { + self.init( + batchSize: batchSize, + remoteBinaryArchiveLocation: URL( + string: "https://storage.googleapis.com/s4tf-hosted-binaries/datasets/CIFAR10/cifar-10-binary.tar.gz")!, + normalizing: true) + } + + public init( + batchSize: Int, + remoteBinaryArchiveLocation: URL, + localStorageDirectory: URL = DatasetUtilities.defaultDirectory + .appendingPathComponent("CIFAR10", isDirectory: true), + normalizing: Bool) + { + _downloadCIFAR10IfNotPresent(from: remoteBinaryArchiveLocation, to: localStorageDirectory) + self.training = Batcher( + on: _loadCIFARTrainingFiles(localStorageDirectory: localStorageDirectory, normalizing: normalizing), + batchSize: batchSize, + numWorkers: 1, //No need to use parallelism since everything is loaded in memory + shuffle: true) + self.test = Batcher( + on: _loadCIFARTestFile(localStorageDirectory: localStorageDirectory, normalizing: normalizing), + batchSize: batchSize, + numWorkers: 1) //No need to use parallelism since everything is loaded in memory + } +} + +func _downloadCIFAR10IfNotPresent(from location: URL, to directory: URL) { + let downloadPath = directory.appendingPathComponent("cifar-10-batches-bin").path + let directoryExists = FileManager.default.fileExists(atPath: downloadPath) + let contentsOfDir = try? FileManager.default.contentsOfDirectory(atPath: downloadPath) + let directoryEmpty = (contentsOfDir == nil) || (contentsOfDir!.isEmpty) + + guard !directoryExists || directoryEmpty else { return } + + let _ = DatasetUtilities.downloadResource( + filename: "cifar-10-binary", fileExtension: "tar.gz", + remoteRoot: location.deletingLastPathComponent(), localStorageDirectory: directory) +} + +func _loadCIFARFile(named name: String, in directory: URL, normalizing: Bool = true) -> [TensorPair] { + let path = directory.appendingPathComponent("cifar-10-batches-bin/\(name)").path + + let imageCount = 10000 + guard let fileContents = try? Data(contentsOf: URL(fileURLWithPath: path)) else { + printError("Could not read dataset file: \(name)") + exit(-1) + } + guard fileContents.count == 30_730_000 else { + printError( + "Dataset file \(name) should have 30730000 bytes, instead had \(fileContents.count)") + exit(-1) + } + + var bytes: [UInt8] = [] + var labels: [Int64] = [] + + let imageByteSize = 3073 + for imageIndex in 0..(shape: [imageCount], scalars: labels) + let images = Tensor(shape: [imageCount, 3, 32, 32], scalars: bytes) + + // Transpose from the CIFAR-provided N(CHW) to TF's default NHWC. + var imageTensor = Tensor(images.transposed(permutation: [0, 2, 3, 1])) + + // The value of mean and std were calculated with the following Swift code: + // ``` + // import TensorFlow + // import Datasets + // import Foundation + // let urlString = "https://storage.googleapis.com/s4tf-hosted-binaries/datasets/CIFAR10/cifar-10-binary.tar.gz" + // let cifar = CIFAR10(batchSize: 50000, + // remoteBinaryArchiveLocation: URL(string: urlString)!, + // normalizing: false) + // for batch in cifar.training.sequenced() { + // let images = Tensor(batch.first) / 255.0 + // let mom = images.moments(squeezingAxes: [0,1,2]) + // print("mean: \(mom.mean) std: \(sqrt(mom.variance))") + // } + // ``` + if normalizing { + let mean = Tensor( + [0.4913996898, + 0.4821584196, + 0.4465309242]) + let std = Tensor( + [0.2470322324, + 0.2434851280, + 0.2615878417]) + imageTensor = ((imageTensor / 255.0) - mean) / std + } + + return (0..(labelTensor[$0])) } + +} + +func _loadCIFARTrainingFiles(localStorageDirectory: URL, normalizing: Bool = true) -> [TensorPair] { + let data = (1..<6).map { + _loadCIFARFile(named: "data_batch_\($0).bin", in: localStorageDirectory, normalizing: normalizing) + } + return data.reduce([], +) +} + +func _loadCIFARTestFile(localStorageDirectory: URL, normalizing: Bool = true) -> [TensorPair] { + return _loadCIFARFile(named: "test_batch.bin", in: localStorageDirectory, normalizing: normalizing) +} \ No newline at end of file diff --git a/Datasets/CMakeLists.txt b/Datasets/CMakeLists.txt index a9708ee10f1..661b09744b8 100644 --- a/Datasets/CMakeLists.txt +++ b/Datasets/CMakeLists.txt @@ -1,5 +1,6 @@ add_library(Datasets CIFAR10/CIFAR10.swift + CIFAR10/OldCIFAR10.swift DatasetUtilities.swift COCO/COCO.swift COCO/COCODataset.swift diff --git a/Datasets/ImageClassificationDataset.swift b/Datasets/ImageClassificationDataset.swift index 58e89995b62..3e88251888b 100644 --- a/Datasets/ImageClassificationDataset.swift +++ b/Datasets/ImageClassificationDataset.swift @@ -14,11 +14,39 @@ import TensorFlow import Batcher +import ModelSupport public protocol ImageClassificationDataset { - associatedtype SourceDataSet: Collection - where SourceDataSet.Element == TensorPair, SourceDataSet.Index == Int - init(batchSize: Int) - var training: Batcher { get } - var test: Batcher { get } + associatedtype SourceDataSet: Collection + where SourceDataSet.Element == TensorPair, SourceDataSet.Index == Int + init(batchSize: Int) + var training: Batcher { get } + var test: Batcher { get } } + +/// An image with a label. +public typealias LabeledImage = LabeledData, Tensor> + +/// Types whose elements represent an image classification dataset (with both +/// training and validation data). +public protocol ImageClassificationData { + /// The type of the training data, represented as a sequence of epochs, which + /// are collection of batches. + associatedtype Training: Sequence + where Training.Element: Collection, Training.Element.Element == LabeledImage + /// The type of the validation data, represented as a collection of batches. + associatedtype Validation: Collection where Validation.Element == LabeledImage + /// Creates an instance from a given `batchSize`. + init(batchSize: Int) + /// The `training` epochs. + var training: Training { get } + /// The `validation` batches. + var validation: Validation { get } + + // The following is probably going to be necessary since we can't extract that + // information from `Epochs` or `Batches`. + /// The number of samples in the `training` set. + //var trainingSampleCount: Int {get} + /// The number of samples in the `validation` set. + //var validationSampleCount: Int {get} +} \ No newline at end of file diff --git a/Examples/Custom-CIFAR10/main.swift b/Examples/Custom-CIFAR10/main.swift index f3ac805f34f..23f0d2418a1 100644 --- a/Examples/Custom-CIFAR10/main.swift +++ b/Examples/Custom-CIFAR10/main.swift @@ -23,12 +23,12 @@ let optimizer = RMSProp(for: model, learningRate: 0.0001, decay: 1e-6) print("Starting training...") -for epoch in 1...100 { +for (epoch, epochBatches) in dataset.training.prefix(100).enumerated() { Context.local.learningPhase = .training var trainingLossSum: Float = 0 var trainingBatchCount = 0 - for batch in dataset.training.sequenced() { - let (images, labels) = (batch.first, batch.second) + for batch in epochBatches { + let (images, labels) = (batch.data, batch.label) let (loss, gradients) = valueWithGradient(at: model) { model -> Tensor in let logits = model(images) return softmaxCrossEntropy(logits: logits, labels: labels) @@ -43,8 +43,8 @@ for epoch in 1...100 { var testBatchCount = 0 var correctGuessCount = 0 var totalGuessCount = 0 - for batch in dataset.test.sequenced() { - let (images, labels) = (batch.first, batch.second) + for batch in dataset.validation { + let (images, labels) = (batch.data, batch.label) let logits = model(images) testLossSum += softmaxCrossEntropy(logits: logits, labels: labels).scalarized() testBatchCount += 1 @@ -64,4 +64,4 @@ for epoch in 1...100 { Loss: \(testLossSum / Float(testBatchCount)) """ ) -} +} \ No newline at end of file diff --git a/Examples/ResNet-CIFAR10/main.swift b/Examples/ResNet-CIFAR10/main.swift index 35743f1b87f..c8fc8a31747 100644 --- a/Examples/ResNet-CIFAR10/main.swift +++ b/Examples/ResNet-CIFAR10/main.swift @@ -29,12 +29,12 @@ let optimizer = SGD(for: model, learningRate: 0.001) print("Starting training...") -for epoch in 1...10 { +for (epoch, epochBatches) in dataset.training.prefix(10).enumerated() { Context.local.learningPhase = .training var trainingLossSum: Float = 0 var trainingBatchCount = 0 - for batch in dataset.training.sequenced() { - let (images, labels) = (batch.first, batch.second) + for batch in epochBatches { + let (images, labels) = (batch.data, batch.label) let (loss, gradients) = valueWithGradient(at: model) { model -> Tensor in let logits = model(images) return softmaxCrossEntropy(logits: logits, labels: labels) @@ -49,8 +49,8 @@ for epoch in 1...10 { var testBatchCount = 0 var correctGuessCount = 0 var totalGuessCount = 0 - for batch in dataset.test.sequenced() { - let (images, labels) = (batch.first, batch.second) + for batch in dataset.validation { + let (images, labels) = (batch.data, batch.label) let logits = model(images) testLossSum += softmaxCrossEntropy(logits: logits, labels: labels).scalarized() testBatchCount += 1 @@ -70,4 +70,4 @@ for epoch in 1...10 { Loss: \(testLossSum / Float(testBatchCount)) """ ) -} +} \ No newline at end of file diff --git a/Support/CMakeLists.txt b/Support/CMakeLists.txt index faef1151801..68f27a6a34b 100644 --- a/Support/CMakeLists.txt +++ b/Support/CMakeLists.txt @@ -13,6 +13,7 @@ add_library(ModelSupport Checkpoints/Protobufs/versions.pb.swift Checkpoints/SnappyDecompression.swift FileManagement.swift + LabeledData.swift Image.swift Stderr.swift Text/BytePairEncoder.swift diff --git a/Support/LabeledData.swift b/Support/LabeledData.swift new file mode 100644 index 00000000000..ed6ea7160ca --- /dev/null +++ b/Support/LabeledData.swift @@ -0,0 +1,43 @@ +// Copyright 2020 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import TensorFlow + +// Note: This is a struct and not a tuple because we need the `Collatable` +// conformance below. +/// A tuple (data, label) that can be used to train a deep learning model. +/// +/// - Parameter `Data`: the type of the input. +/// - Parameter `Label`: the type of the target. +public struct LabeledData { + /// The `data` of our sample (usually used as input for a model). + public let data: Data + /// The `label` of our sample (usually used as target for a model). + public let label: Label + + /// Creates an instance from `data` and `label`. + public init(data: Data, label: Label) { + self.data = data + self.label = label + } +} + +extension LabeledData: Collatable where Data: Collatable, Label: Collatable { + /// Creates an instance from collating `samples`. + public init(collating samples: BatchSamples) + where BatchSamples.Element == Self { + self.init(data: .init(collating: samples.map(\.data)), + label: .init(collating: samples.map(\.label))) + } +} diff --git a/Tests/DatasetsTests/CIFAR10/CIFAR10Tests.swift b/Tests/DatasetsTests/CIFAR10/CIFAR10Tests.swift index 2103157c78b..ec1c15f01bb 100644 --- a/Tests/DatasetsTests/CIFAR10/CIFAR10Tests.swift +++ b/Tests/DatasetsTests/CIFAR10/CIFAR10Tests.swift @@ -7,6 +7,7 @@ final class CIFAR10Tests: XCTestCase { func testCreateCIFAR10() { let dataset = CIFAR10( batchSize: 1, + entropy: SystemRandomNumberGenerator(), remoteBinaryArchiveLocation: URL( string: @@ -16,12 +17,14 @@ final class CIFAR10Tests: XCTestCase { verify(dataset) } - func verify(_ dataset: CIFAR10) { + func verify(_ dataset: CIFAR10) { var totalCount = 0 - for example in dataset.training.sequenced() { - XCTAssertTrue((0..<10).contains(example.second[0].scalar!)) - XCTAssertEqual(example.first.shape, [1, 32, 32, 3]) - totalCount += 1 + for epochBatches in dataset.training.prefix(1){ + for batch in epochBatches { + XCTAssertTrue((0..<10).contains(batch.label[0].scalar!)) + XCTAssertEqual(batch.data.shape, [1, 32, 32, 3]) + totalCount += 1 + } } XCTAssertEqual(totalCount, 50000) } @@ -29,6 +32,7 @@ final class CIFAR10Tests: XCTestCase { func testNormalizeCIFAR10() { let dataset = CIFAR10( batchSize: 50000, + entropy: SystemRandomNumberGenerator(), remoteBinaryArchiveLocation: URL( string: @@ -38,14 +42,16 @@ final class CIFAR10Tests: XCTestCase { let targetMean = Tensor([0, 0, 0]) let targetStd = Tensor([1, 1, 1]) - for batch in dataset.training.sequenced() { - let images = Tensor(batch.first) - let mean = images.mean(squeezingAxes: [0, 1, 2]) - let std = images.standardDeviation(squeezingAxes: [0, 1, 2]) - XCTAssertTrue(targetMean.isAlmostEqual(to: mean, - tolerance: 1e-6)) - XCTAssertTrue(targetStd.isAlmostEqual(to: std, - tolerance: 1e-5)) + for epochBatches in dataset.training.prefix(1){ + for batch in epochBatches { + let images = Tensor(batch.data) + let mean = images.mean(squeezingAxes: [0, 1, 2]) + let std = images.standardDeviation(squeezingAxes: [0, 1, 2]) + XCTAssertTrue(targetMean.isAlmostEqual(to: mean, + tolerance: 1e-6)) + XCTAssertTrue(targetStd.isAlmostEqual(to: std, + tolerance: 1e-5)) + } } } } @@ -55,4 +61,4 @@ extension CIFAR10Tests { ("testCreateCIFAR10", testCreateCIFAR10), ("testNormalizeCIFAR10", testNormalizeCIFAR10), ] -} +} \ No newline at end of file