diff --git a/.gitignore b/.gitignore
index 4fea2e7e10e..87b5aaff3b9 100644
--- a/.gitignore
+++ b/.gitignore
@@ -9,3 +9,7 @@
 cifar-10-batches-py/
 cifar-10-batches-bin/
 output/
+t10k-labels-idx1-ubyte
+t10k-images-idx3-ubyte
+train-labels-idx1-ubyte
+train-images-idx3-ubyte
diff --git a/Benchmarks/Benchmark.swift b/Benchmarks/Benchmark.swift
new file mode 100644
index 00000000000..d50f9babd3a
--- /dev/null
+++ b/Benchmarks/Benchmark.swift
@@ -0,0 +1,104 @@
+// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+import Foundation
+
+enum BenchmarkVariety {
+    case inferenceThroughput(batches: Int, batchSize: Int)
+    case trainingTime
+}
+
+struct BenchmarkResults {
+    let name: String
+    let iterations: Int
+    let timings: [Double]
+    let variety: BenchmarkVariety
+}
+
+extension BenchmarkResults {
+    var interpretedTimings: [Double] {
+        switch self.variety {
+        case let .inferenceThroughput(batches, batchSize):
+            return timings.map { Double(batches * batchSize) / ($0 / 1000.0) }
+        case .trainingTime:
+            return timings
+        }
+    }
+}
+
+/// Performs the specified benchmark over a certain number of iterations and provides the result to a callback function.
+func benchmark(
+    name: String,
+    iterations: Int, variety: BenchmarkVariety,
+    operation: () -> Void,
+    callback: (BenchmarkResults) -> Void
+) {
+    var timings: [Double] = []
+    for _ in 0.. Void) -> Double {
+    let divisor: Double = 1_000_000
+    let start = Double(DispatchTime.now().uptimeNanoseconds) / divisor
+    body()
+    let end = Double(DispatchTime.now().uptimeNanoseconds) / divisor
+    let elapsed = end - start
+    return elapsed
+}
+
+/// Provides the average and standard deviation of an array of values.
+func statistics(for values: [Double]) -> (average: Double, standardDeviation: Double) {
+    guard values.count > 0 else { return (average: 0.0, standardDeviation: 0.0) }
+    guard values.count > 1 else { return (average: values.first!, standardDeviation: 0.0) }
+
+    let average = (values.reduce(0.0) { $0 + $1 }) / Double(values.count)
+
+    let standardDeviation = sqrt(
+        values.reduce(0.0) { $0 + ($1 - average) * ($1 - average) }
+            / Double(values.count - 1))
+
+    return (average: average, standardDeviation: standardDeviation)
+}
+
+// This is a simple callback function example that only logs the result to the console.
+func logResults(_ result: BenchmarkResults) {
+    let (average, standardDeviation) = statistics(for: result.interpretedTimings)
+
+    switch result.variety {
+    case .inferenceThroughput:
+        print(
+            """
+            Benchmark: \(result.name):
+            \tAfter \(result.iterations) iterations:
+            \tSamples per second: \(String(format: "%.2f", average)), standard deviation: \(String(format: "%.2f", standardDeviation))
+            """)
+    case .trainingTime:
+        print(
+            """
+            Benchmark: \(result.name):
+            \tAfter \(result.iterations) iterations:
+            \tAverage: \(String(format: "%.2f", average)) ms, standard deviation: \(String(format: "%.2f", standardDeviation)) ms
+            """)
+    }
+}
diff --git a/Benchmarks/Models/ImageClassificationInference.swift b/Benchmarks/Models/ImageClassificationInference.swift
new file mode 100644
index 00000000000..81e787cb114
--- /dev/null
+++ b/Benchmarks/Models/ImageClassificationInference.swift
@@ -0,0 +1,53 @@
+// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+import TensorFlow
+import Datasets
+import ImageClassificationModels
+
+protocol ImageClassificationModel: Layer where Input == Tensor<Float>, Output == Tensor<Float> {
+    init()
+}
+
+extension LeNet: ImageClassificationModel {}
+
+class ImageClassificationInference<Model> where Model: ImageClassificationModel {
+    // TODO: (https://github.com/tensorflow/swift-models/issues/206) Datasets should have a common
+    // interface to allow for them to be interchangeable in these benchmark cases.
+    let dataset: MNIST
+    var model: Model
+    let images: Tensor<Float>
+    let batches: Int
+    let batchSize: Int
+
+    init(batches: Int, batchSize: Int, images: Tensor<Float>? = nil) {
+        self.batches = batches
+        self.batchSize = batchSize
+        self.dataset = MNIST(batchSize: batchSize)
+        self.model = Model()
+        if let providedImages = images {
+            self.images = providedImages
+        } else {
+            self.images = Tensor(
+                randomNormal: [batchSize, 28, 28, 1], mean: Tensor(0.5),
+                standardDeviation: Tensor(0.1), seed: (0xffeffe, 0xfffe))
+        }
+    }
+
+    func performInference() {
+        for _ in 0..
+where Model: ImageClassificationModel, Model.TangentVector.VectorSpaceScalar == Float {
+    // TODO: (https://github.com/tensorflow/swift-models/issues/206) Datasets should have a common
+    // interface to allow for them to be interchangeable in these benchmark cases.
+    let dataset: MNIST
+    let epochs: Int
+    let batchSize: Int
+
+    init(epochs: Int, batchSize: Int) {
+        self.epochs = epochs
+        self.batchSize = batchSize
+        self.dataset = MNIST(batchSize: batchSize)
+    }
+
+    func train() {
+        var model = Model()
+        // TODO: Split out the optimizer as a separate specification.
+        let optimizer = SGD(for: model, learningRate: 0.1)
+
+        Context.local.learningPhase = .training
+        for _ in 1...epochs {
+            for i in 0..
Tensor in
+                    let ŷ = model(x)
+                    return softmaxCrossEntropy(logits: ŷ, labels: y)
+                }
+                optimizer.update(&model, along: 𝛁model)
+            }
+        }
+    }
+}
diff --git a/Benchmarks/README.md b/Benchmarks/README.md
new file mode 100644
index 00000000000..18e6c352f5b
--- /dev/null
+++ b/Benchmarks/README.md
@@ -0,0 +1,22 @@
+# Model benchmarks
+
+Eventually, this directory will contain a series of benchmarks against a variety of models in the
+swift-models repository. The following benchmarks have been implemented:
+
+- Training LeNet against the MNIST dataset
+- Performing inference with LeNet using MNIST-sized random images
+
+These benchmarks should provide a baseline to judge performance improvements and regressions in
+Swift for TensorFlow.
+
+## Running benchmarks
+
+To begin, you'll need the [latest version of Swift for
+TensorFlow](https://github.com/tensorflow/swift/blob/master/Installation.md)
+installed. Make sure you've added the correct version of `swift` to your path.
+
+To run all benchmarks, type the following while in the swift-models directory:
+
+```sh
+swift run -c release Benchmarks
+```
\ No newline at end of file
diff --git a/Benchmarks/main.swift b/Benchmarks/main.swift
new file mode 100644
index 00000000000..1b2fb970116
--- /dev/null
+++ b/Benchmarks/main.swift
@@ -0,0 +1,29 @@
+// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+import ImageClassificationModels
+
+// LeNet-MNIST
+let leNetTrainingBenchmark = ImageClassificationTraining<LeNet>(epochs: 1, batchSize: 128)
+benchmark(
+    name: "LeNet-MNIST (training)",
+    iterations: 10, variety: .trainingTime, operation: leNetTrainingBenchmark.train,
+    callback: logResults)
+
+let leNetInferenceBenchmark = ImageClassificationInference<LeNet>(batches: 1000, batchSize: 1)
+benchmark(
+    name: "LeNet-MNIST (inference)",
+    iterations: 10, variety: .inferenceThroughput(batches: 1000, batchSize: 1),
+    operation: leNetInferenceBenchmark.performInference,
+    callback: logResults)
diff --git a/Catch/main.swift b/Catch/main.swift
index 2ae144769b0..61307a18a60 100644
--- a/Catch/main.swift
+++ b/Catch/main.swift
@@ -89,12 +89,11 @@ extension CatchAgent {
     func perfectAction(for observation: Observation) -> Action {
         let paddleX = observation.scalars[0]
         let ballX = observation.scalars[1]
-        if paddleX > ballX {
-            return .right
-        } else if paddleX < ballX {
-            return .left
+        switch paddleX {
+        case ballX: return .none
+        case .. CIFARExamp
     let images = Tensor(shape: [imageCount, 3, 32, 32], scalars: bytes)
     // Transpose from the CIFAR-provided N(CHW) to TF's default NHWC.
- let imageTensor = Tensor(images.transposed(withPermutations: [0, 2, 3, 1])) + let imageTensor = Tensor(images.transposed(permutation: [0, 2, 3, 1])) let mean = Tensor([0.485, 0.456, 0.406]) let std = Tensor([0.229, 0.224, 0.225]) @@ -125,8 +125,8 @@ func loadCIFARFile(named name: String, in directory: String = ".") -> CIFARExamp func loadCIFARTrainingFiles() -> CIFARExample { let data = (1..<6).map { loadCIFARFile(named: "data_batch_\($0).bin") } return CIFARExample( - label: Raw.concat(concatDim: Tensor(0), data.map { $0.label }), - data: Raw.concat(concatDim: Tensor(0), data.map { $0.data }) + label: _Raw.concat(concatDim: Tensor(0), data.map { $0.label }), + data: _Raw.concat(concatDim: Tensor(0), data.map { $0.data }) ) } diff --git a/Datasets/DatasetUtilities.swift b/Datasets/DatasetUtilities.swift new file mode 100644 index 00000000000..dab0b28864a --- /dev/null +++ b/Datasets/DatasetUtilities.swift @@ -0,0 +1,112 @@ +// Copyright 2019 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import Foundation + +#if canImport(FoundationNetworking) + import FoundationNetworking +#endif + +public struct DatasetUtilities { + public static let curentWorkingDirectoryURL = URL( + fileURLWithPath: FileManager.default.currentDirectoryPath) + + public static func fetchResource( + filename: String, + remoteRoot: URL, + localStorageDirectory: URL = curentWorkingDirectoryURL + ) -> Data { + print("Loading resource: \(filename)") + + let resource = ResourceDefinition( + filename: filename, + remoteRoot: remoteRoot, + localStorageDirectory: localStorageDirectory) + + let localURL = resource.localURL + + if !FileManager.default.fileExists(atPath: localURL.path) { + print( + "File does not exist locally at expected path: \(localURL.path) and must be fetched" + ) + fetchFromRemoteAndSave(resource) + } + + do { + print("Loading local data at: \(localURL.path)") + let data = try Data(contentsOf: localURL) + print("Succesfully loaded resource: \(filename)") + return data + } catch { + fatalError("Failed to contents of resource: \(localURL)") + } + } + + struct ResourceDefinition { + let filename: String + let remoteRoot: URL + let localStorageDirectory: URL + + var localURL: URL { + localStorageDirectory.appendingPathComponent(filename) + } + + var remoteURL: URL { + remoteRoot.appendingPathComponent(filename).appendingPathExtension("gz") + } + + var archiveURL: URL { + localURL.appendingPathExtension("gz") + } + } + + static func fetchFromRemoteAndSave(_ resource: ResourceDefinition) { + let remoteLocation = resource.remoteURL + let archiveLocation = resource.archiveURL + + do { + print("Fetching URL: \(remoteLocation)...") + let archiveData = try Data(contentsOf: remoteLocation) + print("Writing fetched archive to: \(archiveLocation.path)") + try archiveData.write(to: archiveLocation) + } catch { + fatalError("Failed to fetch and save resource with error: \(error)") + } + print("Archive saved to: \(archiveLocation.path)") + + extractArchive(for: 
resource) + } + + static func extractArchive(for resource: ResourceDefinition) { + print("Extracting archive...") + + let archivePath = resource.archiveURL.path + + #if os(macOS) + let gunzipLocation = "/usr/bin/gunzip" + #else + let gunzipLocation = "/bin/gunzip" + #endif + + let task = Process() + task.executableURL = URL(fileURLWithPath: gunzipLocation) + task.arguments = [archivePath] + do { + try task.run() + task.waitUntilExit() + } catch { + fatalError("Failed to extract \(archivePath) with error: \(error)") + } + } +} diff --git a/Datasets/MNIST/MNIST.swift b/Datasets/MNIST/MNIST.swift index 18256088a3f..6a79b4b5d65 100644 --- a/Datasets/MNIST/MNIST.swift +++ b/Datasets/MNIST/MNIST.swift @@ -31,21 +31,27 @@ public struct MNIST { public let batchSize: Int - public init(batchSize: Int, flattening: Bool = false, normalizing: Bool = false) { + public init( + batchSize: Int, flattening: Bool = false, normalizing: Bool = false, + localStorageDirectory: URL = DatasetUtilities.curentWorkingDirectoryURL + ) { self.batchSize = batchSize - let (trainingImages, trainingLabels) = readMNIST( - imagesFile: "train-images-idx3-ubyte", - labelsFile: "train-labels-idx1-ubyte", + let (trainingImages, trainingLabels) = fetchDataset( + localStorageDirectory: localStorageDirectory, + imagesFilename: "train-images-idx3-ubyte", + labelsFilename: "train-labels-idx1-ubyte", flattening: flattening, normalizing: normalizing) + self.trainingImages = trainingImages self.trainingLabels = trainingLabels self.trainingSize = Int(trainingLabels.shape[0]) - let (testImages, testLabels) = readMNIST( - imagesFile: "t10k-images-idx3-ubyte", - labelsFile: "t10k-labels-idx1-ubyte", + let (testImages, testLabels) = fetchDataset( + localStorageDirectory: localStorageDirectory, + imagesFilename: "t10k-images-idx3-ubyte", + labelsFilename: "t10k-labels-idx1-ubyte", flattening: flattening, normalizing: normalizing) self.testImages = testImages @@ -61,36 +67,31 @@ extension Tensor { } } -/// Reads a file into an array of bytes. -func readFile(_ path: String, possibleDirectories: [String]) -> [UInt8] { - for folder in possibleDirectories { - let parent = URL(fileURLWithPath: folder) - let filePath = parent.appendingPathComponent(path) - guard FileManager.default.fileExists(atPath: filePath.path) else { - continue - } - let data = try! Data(contentsOf: filePath, options: []) - return [UInt8](data) +fileprivate func fetchDataset( + localStorageDirectory: URL, + imagesFilename: String, + labelsFilename: String, + flattening: Bool, + normalizing: Bool +) -> (images: Tensor, labels: Tensor) { + guard let remoteRoot: URL = URL(string: "http://yann.lecun.com/exdb/mnist") else { + fatalError("Failed to create MNST root url: http://yann.lecun.com/exdb/mnist") } - print("File not found: \(path)") - exit(-1) -} -/// Reads MNIST images and labels from specified file paths. 
-func readMNIST(imagesFile: String, labelsFile: String, flattening: Bool, normalizing: Bool) -> ( - images: Tensor, - labels: Tensor -) { - print("Reading data from files: \(imagesFile), \(labelsFile).") - let images = readFile(imagesFile, possibleDirectories: [".", "./Datasets/MNIST"]).dropFirst(16) - .map(Float.init) - let labels = readFile(labelsFile, possibleDirectories: [".", "./Datasets/MNIST"]).dropFirst(8) - .map(Int32.init) - let rowCount = labels.count - let imageHeight = 28 - let imageWidth = 28 + let imagesData = DatasetUtilities.fetchResource( + filename: imagesFilename, + remoteRoot: remoteRoot, + localStorageDirectory: localStorageDirectory) + let labelsData = DatasetUtilities.fetchResource( + filename: labelsFilename, + remoteRoot: remoteRoot, + localStorageDirectory: localStorageDirectory) + + let images = [UInt8](imagesData).dropFirst(16).map(Float.init) + let labels = [UInt8](labelsData).dropFirst(8).map(Int32.init) - print("Constructing data tensors.") + let rowCount = labels.count + let (imageWidth, imageHeight) = (28, 28) if flattening { var flattenedImages = Tensor(shape: [rowCount, imageHeight * imageWidth], scalars: images) @@ -101,8 +102,9 @@ func readMNIST(imagesFile: String, labelsFile: String, flattening: Bool, normali return (images: flattenedImages, labels: Tensor(labels)) } else { return ( - images: Tensor(shape: [rowCount, 1, imageHeight, imageWidth], scalars: images) - .transposed(withPermutations: [0, 2, 3, 1]) / 255, // NHWC + images: + Tensor(shape: [rowCount, 1, imageHeight, imageWidth], scalars: images) + .transposed(permutation: [0, 2, 3, 1]) / 255, // NHWC labels: Tensor(labels) ) } diff --git a/GAN/main.swift b/GAN/main.swift index f57a550274f..8e09985ec34 100644 --- a/GAN/main.swift +++ b/GAN/main.swift @@ -126,7 +126,7 @@ func saveImageGrid(_ testImage: Tensor, name: String) throws { // Add padding. gridImage = gridImage.padded(forSizes: [(0, 0), (0, 0), (1, 1), (1, 1)], with: 1) // Transpose to create single image. - gridImage = gridImage.transposed(withPermutations: [0, 2, 1, 3]) + gridImage = gridImage.transposed(permutation: [0, 2, 1, 3]) gridImage = gridImage.reshaped( to: [ (imageHeight + 2) * testImageGridSize, diff --git a/Gym/CartPole/main.swift b/Gym/CartPole/main.swift index 2825e1400f5..a0ba2df1186 100644 --- a/Gym/CartPole/main.swift +++ b/Gym/CartPole/main.swift @@ -68,8 +68,8 @@ struct Episode { /// Filtering out bad/short episodes before we feed them as neural net training data. func filteringBatch( - episodes: [Episode], - actionCount: Int + episodes: [Episode], + actionCount: Int ) -> (input: Tensor, target: Tensor, episodeCount: Int, meanReward: Float) { let rewards = episodes.map { $0.reward } let rewardBound = Float(np.percentile(rewards, percentile))! 
@@ -111,10 +111,10 @@ func filteringBatch( } func nextBatch( - env: PythonObject, - net: Net, - batchSize: Int, - actionCount: Int + env: PythonObject, + net: Net, + batchSize: Int, + actionCount: Int ) -> [Episode] { var observationNumpy = env.reset() @@ -127,8 +127,7 @@ func nextBatch( while true { let observationPython = Tensor(numpy: observationNumpy).unwrapped() - let actionProbabilities = - softmax(net(Tensor(observationPython).reshaped(to: [1, 4]))) + let actionProbabilities = softmax(net(Tensor(observationPython).reshaped(to: [1, 4]))) let actionProbabilitiesPython = actionProbabilities[0].makeNumpyArray() let len = Python.len(actionProbabilitiesPython) assert(actionCount == Int(Python.len(actionProbabilitiesPython))) @@ -138,8 +137,10 @@ func nextBatch( // print(nextObservation) // print(reward) - steps.append(Episode.Step(observation: Tensor(observationPython), - action: Int32(actionPython).unwrapped())) + steps.append( + Episode.Step( + observation: Tensor(observationPython), + action: Int32(actionPython).unwrapped())) episodeReward += Float(reward).unwrapped() @@ -162,7 +163,8 @@ let observationSize = Int(env.observation_space.shape[0]).unwrapped() let actionCount = Int(env.action_space.n).unwrapped() // print(actionCount) -var net = Net(observationSize: Int(observationSize), hiddenSize: hiddenSize, actionCount: actionCount) +var net = Net( + observationSize: Int(observationSize), hiddenSize: hiddenSize, actionCount: actionCount) // SGD optimizer reaches convergence with ~125 mini batches, while Adam uses ~25. // let optimizer = SGD(learningRate: 0.1, momentum: 0.9) let optimizer = Adam(for: net, learningRate: 0.01) @@ -174,7 +176,7 @@ while true { let episodes = nextBatch(env: env, net: net, batchSize: batchSize, actionCount: actionCount) let (input, target, episodeCount, meanReward) = filteringBatch( - episodes: episodes, actionCount: actionCount) + episodes: episodes, actionCount: actionCount) let gradients = withLearningPhase(.training) { net.gradient { net -> Tensor in diff --git a/MiniGo/Models/GoModel.swift b/MiniGo/Models/GoModel.swift index bdd61fb99a2..d5b34425a94 100644 --- a/MiniGo/Models/GoModel.swift +++ b/MiniGo/Models/GoModel.swift @@ -54,10 +54,7 @@ struct ConvBN: Layer { // TODO(jekbradbury): thread through bias and affine boolean arguments // (behavior is correct for inference but this should be changed for training) self.conv = Conv2D(filterShape: filterShape, strides: strides, padding: padding) - self.norm = BatchNorm( - featureCount: filterShape.3, - momentum: Tensor(0.95), - epsilon: Tensor(1e-5)) + self.norm = BatchNorm(featureCount: filterShape.3, momentum: 0.95, epsilon: 1e-5) } @differentiable diff --git a/MiniGo/Models/PythonCheckpointReader.swift b/MiniGo/Models/PythonCheckpointReader.swift index 2fe1a950032..58c1746a7e8 100644 --- a/MiniGo/Models/PythonCheckpointReader.swift +++ b/MiniGo/Models/PythonCheckpointReader.swift @@ -28,7 +28,7 @@ public class PythonCheckpointReader { let countSuffix = layerCounts[layerName] == nil ? 
"" : "_\(layerCounts[layerName]!)" let tensorName = layerName + countSuffix + "/" + weightName // TODO(jekbradbury): support variadic dtype attrs in RawOpsGenerated - return Raw.restoreV2(prefix: StringTensor(path), + return _Raw.restoreV2(prefix: StringTensor(path), tensorNames: StringTensor([tensorName]), shapeAndSlices: StringTensor([""])) } diff --git a/MiniGo/Strategies/MCTS/MCTSModelBasePredictor.swift b/MiniGo/Strategies/MCTS/MCTSModelBasePredictor.swift index b3919320c9e..2d7cdf1f3a8 100644 --- a/MiniGo/Strategies/MCTS/MCTSModelBasePredictor.swift +++ b/MiniGo/Strategies/MCTS/MCTSModelBasePredictor.swift @@ -119,7 +119,7 @@ extension BoardState { // boardSize, featurePlanes]` order. // // Rotate our inputs to this order by transposing and reshape to a single-element batch. - return featureTensor.transposed(withPermutations: 1, 2, 0) + return featureTensor.transposed(permutation: 1, 2, 0) .reshaped(to: [1, boardSize, boardSize, 17]) } } diff --git a/Models/ImageClassification/DenseNet121.swift b/Models/ImageClassification/DenseNet121.swift new file mode 100644 index 00000000000..35a9b579f3f --- /dev/null +++ b/Models/ImageClassification/DenseNet121.swift @@ -0,0 +1,146 @@ +// Copyright 2019 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +import TensorFlow + +// Original Paper: +// Densely Connected Convolutional Networks +// Gao Huang, Zhuang Liu, Laurens van der Maaten, Kilian Q. 
Weinberger +// https://arxiv.org/pdf/1608.06993.pdf + +public struct DenseNet121: Layer { + public var conv = Conv( + filterSize: 7, + stride: 2, + inputFilterCount: 3, + outputFilterCount: 64 + ) + public var maxpool = MaxPool2D( + poolSize: (3, 3), + strides: (2, 2), + padding: .same + ) + public var denseBlock1 = DenseBlock(repetitionCount: 6, inputFilterCount: 64) + public var transitionLayer1 = TransitionLayer(inputFilterCount: 256) + public var denseBlock2 = DenseBlock(repetitionCount: 12, inputFilterCount: 128) + public var transitionLayer2 = TransitionLayer(inputFilterCount: 512) + public var denseBlock3 = DenseBlock(repetitionCount: 24, inputFilterCount: 256) + public var transitionLayer3 = TransitionLayer(inputFilterCount: 1024) + public var denseBlock4 = DenseBlock(repetitionCount: 16, inputFilterCount: 512) + public var globalAvgPool = GlobalAvgPool2D() + public var dense: Dense + + public init(classCount: Int) { + dense = Dense(inputSize: 1024, outputSize: classCount, activation: softmax) + } + + @differentiable + public func callAsFunction(_ input: Tensor) -> Tensor { + let inputLayer = input.sequenced(through: conv, maxpool) + let level1 = inputLayer.sequenced(through: denseBlock1, transitionLayer1) + let level2 = level1.sequenced(through: denseBlock2, transitionLayer2) + let level3 = level2.sequenced(through: denseBlock3, transitionLayer3) + let output = level3.sequenced(through: denseBlock4, globalAvgPool, dense) + return output + } +} + +extension DenseNet121 { + public struct Conv: Layer { + public var batchNorm: BatchNorm + public var conv: Conv2D + + public init( + filterSize: Int, + stride: Int = 1, + inputFilterCount: Int, + outputFilterCount: Int + ) { + batchNorm = BatchNorm(featureCount: inputFilterCount) + conv = Conv2D( + filterShape: (filterSize, filterSize, inputFilterCount, outputFilterCount), + strides: (stride, stride), + padding: .same + ) + } + + @differentiable + public func callAsFunction(_ input: Tensor) -> Tensor { + conv(relu(batchNorm(input))) + } + } + + /// A pair of a 1x1 `Conv` layer and a 3x3 `Conv` layer. + public struct ConvPair: Layer { + public var conv1x1: Conv + public var conv3x3: Conv + + public init(inputFilterCount: Int, growthRate: Int) { + conv1x1 = Conv( + filterSize: 1, + inputFilterCount: inputFilterCount, + outputFilterCount: inputFilterCount * 2 + ) + conv3x3 = Conv( + filterSize: 3, + inputFilterCount: inputFilterCount * 2, + outputFilterCount: growthRate + ) + } + + @differentiable + public func callAsFunction(_ input: Tensor) -> Tensor { + let conv1Output = conv1x1(input) + let conv3Output = conv3x3(conv1Output) + return conv3Output.concatenated(with: input, alongAxis: -1) + } + } + + public struct DenseBlock: Layer { + public var pairs: [ConvPair] = [] + + public init(repetitionCount: Int, growthRate: Int = 32, inputFilterCount: Int) { + for i in 0..) 
-> Tensor { + pairs.differentiableReduce(input) { last, layer in + layer(last) + } + } + } + + public struct TransitionLayer: Layer { + public var conv: Conv + public var pool: AvgPool2D + + public init(inputFilterCount: Int) { + conv = Conv( + filterSize: 1, + inputFilterCount: inputFilterCount, + outputFilterCount: inputFilterCount / 2 + ) + pool = AvgPool2D(poolSize: (2, 2), strides: (2, 2), padding: .same) + } + + @differentiable + public func callAsFunction(_ input: Tensor) -> Tensor { + input.sequenced(through: conv, pool) + } + } +} diff --git a/Package.swift b/Package.swift index 20e140f6aad..302355d0218 100644 --- a/Package.swift +++ b/Package.swift @@ -18,12 +18,14 @@ let package = Package( .executable(name: "MiniGoDemo", targets: ["MiniGoDemo"]), .library(name: "MiniGo", targets: ["MiniGo"]), .executable(name: "GAN", targets: ["GAN"]), + .executable(name: "Benchmarks", targets: ["Benchmarks"]), ], targets: [ .target(name: "ImageClassificationModels", path: "Models/ImageClassification"), .target(name: "Datasets", path: "Datasets"), .target(name: "ModelSupport", path: "Support"), - .target(name: "Autoencoder", dependencies: ["Datasets", "ModelSupport"], path: "Autoencoder"), + .target( + name: "Autoencoder", dependencies: ["Datasets", "ModelSupport"], path: "Autoencoder"), .target(name: "Catch", path: "Catch"), .target(name: "Gym-FrozenLake", path: "Gym/FrozenLake"), .target(name: "Gym-CartPole", path: "Gym/CartPole"), @@ -45,5 +47,9 @@ let package = Package( .testTarget(name: "ImageClassificationTests", dependencies: ["ImageClassificationModels"]), .target(name: "Transformer", path: "Transformer"), .target(name: "GAN", dependencies: ["Datasets", "ModelSupport"], path: "GAN"), + .target( + name: "Benchmarks", + dependencies: ["Datasets", "ModelSupport", "ImageClassificationModels"], + path: "Benchmarks"), ] ) diff --git a/Support/Image.swift b/Support/Image.swift index 8cef015a6c0..2df5d537ea4 100644 --- a/Support/Image.swift +++ b/Support/Image.swift @@ -37,11 +37,11 @@ public struct Image { } public init(jpeg url: URL, byteOrdering: ByteOrdering = .rgb) { - let loadedFile = Raw.readFile(filename: StringTensor(url.absoluteString)) - let loadedJpeg = Raw.decodeJpeg(contents: loadedFile, channels: 3, dctMethod: "") + let loadedFile = _Raw.readFile(filename: StringTensor(url.absoluteString)) + let loadedJpeg = _Raw.decodeJpeg(contents: loadedFile, channels: 3, dctMethod: "") if byteOrdering == .bgr { self.imageData = .uint8( - data: Raw.reverse(loadedJpeg, dims: Tensor([false, false, false, true]))) + data: _Raw.reverse(loadedJpeg, dims: Tensor([false, false, false, true]))) } else { self.imageData = .uint8(data: loadedJpeg) } @@ -59,21 +59,21 @@ public struct Image { outputImageData = Tensor(adjustedData) } - let encodedJpeg = Raw.encodeJpeg( + let encodedJpeg = _Raw.encodeJpeg( image: outputImageData, format: .grayscale, quality: quality, xmpMetadata: "") - Raw.writeFile(filename: StringTensor(url.absoluteString), contents: encodedJpeg) + _Raw.writeFile(filename: StringTensor(url.absoluteString), contents: encodedJpeg) } public func resized(to size: (Int, Int)) -> Image { switch self.imageData { case let .uint8(data): return Image( - tensor: Raw.resizeBilinear( + tensor: _Raw.resizeBilinear( images: Tensor([data]), size: Tensor([Int32(size.0), Int32(size.1)]))) case let .float(data): return Image( - tensor: Raw.resizeBilinear( + tensor: _Raw.resizeBilinear( images: Tensor([data]), size: Tensor([Int32(size.0), Int32(size.1)]))) } diff --git 
a/Tests/ImageClassificationTests/Inference.swift b/Tests/ImageClassificationTests/Inference.swift index 710cc691ac6..99105c30f23 100644 --- a/Tests/ImageClassificationTests/Inference.swift +++ b/Tests/ImageClassificationTests/Inference.swift @@ -21,6 +21,15 @@ final class ImageClassificationInferenceTests: XCTestCase { override class func setUp() { Context.local.learningPhase = .inference } + + func testDenseNet121() { + let input = Tensor( + randomNormal: [1, 224, 224, 3], mean: Tensor(0.5), + standardDeviation: Tensor(0.1), seed: (0xffeffe, 0xfffe)) + let denseNet121 = DenseNet121(classCount: 1000) + let denseNet121Result = denseNet121(input) + XCTAssertEqual(denseNet121Result.shape, [1, 1000]) + } func testLeNet() { let leNet = LeNet() @@ -158,6 +167,7 @@ final class ImageClassificationInferenceTests: XCTestCase { extension ImageClassificationInferenceTests { static var allTests = [ + ("testDenseNet121", testDenseNet121), ("testLeNet", testLeNet), ("testResNet", testResNet), ("testResNetV2", testResNetV2), diff --git a/Transformer/Model.swift b/Transformer/Model.swift index 6d23851b16d..cc9181863f8 100644 --- a/Transformer/Model.swift +++ b/Transformer/Model.swift @@ -92,7 +92,7 @@ func causallyMasked(_ dotProducts: Tensor, enable: Bool = false) -> Tenso } let (queryTimeSteps, keyTimeSteps) = (dotProducts.shape[1], dotProducts.shape[2]) let ones = Tensor(ones: [1, queryTimeSteps, keyTimeSteps]) - let mask = Raw.matrixBandPart( + let mask = _Raw.matrixBandPart( ones, numLower: Tensor(Int32(-1)), numUpper: Tensor(Int32(queryTimeSteps - keyTimeSteps))) @@ -102,7 +102,7 @@ func causallyMasked(_ dotProducts: Tensor, enable: Bool = false) -> Tenso // causal mask is intentionally invisible to differentiation func _vjpCausallyMasked(_ dotProducts: Tensor, enable: Bool) -> (Tensor, (Tensor) -> Tensor) { - return (causallyMasked(dotProducts), identity) + return (causallyMasked(dotProducts, enable: enable), identity) } struct Attention: ParameterlessLayer { @@ -138,7 +138,7 @@ func splitHeads(_ input: Tensor, headCount: Int) -> Tensor { let (batchSize, timeSteps, features) = (input.shape[0], input.shape[1], input.shape[2]) let featuresPerHead = features / headCount let splitLastDim = input.reshaped(to: [batchSize, timeSteps, headCount, featuresPerHead]) - let movedToFront = splitLastDim.transposed(withPermutations: 0, 2, 1, 3) + let movedToFront = splitLastDim.transposed(permutation: 0, 2, 1, 3) return movedToFront.reshaped(to: [batchSize * headCount, timeSteps, featuresPerHead]) } @@ -149,7 +149,7 @@ func joinHeads(_ input: Tensor, headCount: Int) -> Tensor { let batchSize = generalizedBatch / headCount let features = featuresPerHead * headCount let splitFirstDim = input.reshaped(to: [batchSize, headCount, timeSteps, featuresPerHead]) - let movedToBack = splitFirstDim.transposed(withPermutations: 0, 2, 1, 3) + let movedToBack = splitFirstDim.transposed(permutation: 0, 2, 1, 3) return movedToBack.reshaped(to: [batchSize, timeSteps, features]) } @@ -173,7 +173,7 @@ func _vjpSplitQKV(_ input: Tensor) -> (AttentionInput, (AttentionInput.TangentVector) -> Tensor) { let value = splitQKV(input) return (value, { seed in - return Raw.concatV2([seed.query, seed.key, seed.value], axis: Tensor(2)) + return _Raw.concatV2([seed.query, seed.key, seed.value], axis: Tensor(2)) }) } @@ -227,10 +227,10 @@ struct EncoderLayer: Layer { size: size, headCount: headCount) selfAttentionDropout = Dropout(probability: dropProbability) - selfAttentionNorm = LayerNorm(featureCount: size, axis: 2, epsilon: Tensor(1e-5)) + 
selfAttentionNorm = LayerNorm(featureCount: size, axis: 2, epsilon: 1e-5) feedForward = FeedForward(size: size, hidden: 4 * size, dropProbability: dropProbability) feedForwardDropout = Dropout(probability: dropProbability) - feedForwardNorm = LayerNorm(featureCount: size, axis: 2, epsilon: Tensor(1e-5)) + feedForwardNorm = LayerNorm(featureCount: size, axis: 2, epsilon: 1e-5) } @differentiable(wrt: (self, input)) diff --git a/Transformer/Operators.swift b/Transformer/Operators.swift index 20944d45bd9..60a5e82b7f5 100644 --- a/Transformer/Operators.swift +++ b/Transformer/Operators.swift @@ -37,7 +37,7 @@ func batchedMatmul( adjointLeft: Bool = false, adjointRight: Bool = false ) -> Tensor { - return Raw.batchMatMul(left, right, adjX: adjointLeft, adjY: adjointRight) + return _Raw.batchMatMul(left, right, adjX: adjointLeft, adjY: adjointRight) } @usableFromInline diff --git a/Transformer/PythonCheckpointReader.swift b/Transformer/PythonCheckpointReader.swift index 1a6df989adf..9597cd28e25 100644 --- a/Transformer/PythonCheckpointReader.swift +++ b/Transformer/PythonCheckpointReader.swift @@ -36,7 +36,7 @@ func readTensor( scalarType: Scalar.Type ) -> Tensor { // TODO(jekbradbury): support variadic dtype attrs in RawOpsGenerated - return Raw.restoreV2(prefix: StringTensor(path), + return _Raw.restoreV2(prefix: StringTensor(path), tensorNames: StringTensor([name]), shapeAndSlices: StringTensor([""])) } @@ -73,7 +73,7 @@ extension LayerNorm: InitializableFromPythonCheckpoint { offset: readTensor(fromPath: path, name: scope + "/b", scalarType: Scalar.self), scale: readTensor(fromPath: path, name: scope + "/g", scalarType: Scalar.self), axis: -1, - epsilon: Tensor(1e-5)) + epsilon: 1e-5) } } diff --git a/Transformer/main.swift b/Transformer/main.swift index 0cc91eb2171..a3d4a4fa5f7 100644 --- a/Transformer/main.swift +++ b/Transformer/main.swift @@ -53,7 +53,7 @@ for _ in 0..<100 { let lastLogit = logits.slice( lowerBounds: [0, timeSteps - 1, 0], upperBounds: [batchSize, timeSteps, vocabSize]) / temperature - tokens = Raw.multinomial(logits: lastLogit.squeezingShape(at: 1), numSamples: Tensor(1)) + tokens = _Raw.multinomial(logits: lastLogit.squeezingShape(at: 1), numSamples: Tensor(1)) print(encoder.decode(tokens[0].makeNumpyArray()), terminator: "") } print()
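The harness added in `Benchmarks/Benchmark.swift` above accepts any `() -> Void` operation and any `(BenchmarkResults) -> Void` callback, so additional benchmarks can be registered from `Benchmarks/main.swift` without touching the measurement code. Below is a minimal sketch of that, assuming only the APIs introduced in this diff (`benchmark`, `statistics`, `BenchmarkResults`, `ImageClassificationInference`); the CSV-style callback and the batch size of 128 are illustrative and not part of this PR.

```swift
// Hypothetical addition to Benchmarks/main.swift; a sketch, not part of this diff.

// Any (BenchmarkResults) -> Void closure works as a callback; this one emits a single
// comma-separated line per benchmark instead of the multi-line log from logResults.
func csvResults(_ result: BenchmarkResults) {
    let (average, standardDeviation) = statistics(for: result.interpretedTimings)
    print("\(result.name),\(result.iterations),\(average),\(standardDeviation)")
}

// LeNet inference at an illustrative batch size of 128, reusing the generic
// ImageClassificationInference wrapper defined in Benchmarks/Models/.
let batchedInference = ImageClassificationInference<LeNet>(batches: 100, batchSize: 128)
benchmark(
    name: "LeNet-MNIST (inference, batch size 128)",
    iterations: 10, variety: .inferenceThroughput(batches: 100, batchSize: 128),
    operation: batchedInference.performInference,
    callback: csvResults)
```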