-
Notifications
You must be signed in to change notification settings - Fork 150
Replaced Python with Swift in CIFAR10 dataset loading #178
Changes from 2 commits
0763262
eec84bf
10d1b88
f29ded3
9f704dd
4d6d5f9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -7,3 +7,4 @@ | |
| .DS_Store | ||
| .swiftpm | ||
| cifar-10-batches-py/ | ||
| cifar-10-batches-bin/ | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,4 +1,4 @@ | ||
| // Copyright 2018 The TensorFlow Authors. All Rights Reserved. | ||
| // Copyright 2019 The TensorFlow Authors. All Rights Reserved. | ||
| // | ||
| // Licensed under the Apache License, Version 2.0 (the "License"); | ||
| // you may not use this file except in compliance with the License. | ||
|
|
@@ -12,24 +12,63 @@ | |
| // See the License for the specific language governing permissions and | ||
| // limitations under the License. | ||
|
|
||
| import Python | ||
| import Foundation | ||
| import FoundationNetworking | ||
| import TensorFlow | ||
|
|
||
| /// Use Python and shell calls to download and extract the CIFAR-10 tarball if not already done | ||
| /// This can fail for many reasons (e.g. lack of `wget`, `tar`, or an Internet connection) | ||
| func downloadCIFAR10IfNotPresent(to directory: String = ".") { | ||
| let subprocess = Python.import("subprocess") | ||
| let path = Python.import("os.path") | ||
| let filepath = "\(directory)/cifar-10-batches-py" | ||
| let isdir = Bool(path.isdir(filepath))! | ||
| if !isdir { | ||
| print("Downloading CIFAR data...") | ||
| let command = "wget -nv -O- https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz | tar xzf - -C \(directory)" | ||
| subprocess.call(command, shell: true) | ||
| let downloadPath = "\(directory)/cifar-10-batches-bin" | ||
| let directoryExists = FileManager.default.fileExists(atPath: downloadPath) | ||
|
|
||
| if !directoryExists { | ||
BradLarson marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| print("Downloading CIFAR dataset...") | ||
| let archivePath = "\(directory)/cifar-10-binary.tar.gz" | ||
| let archiveExists = FileManager.default.fileExists(atPath: archivePath) | ||
| if !archiveExists { | ||
| print("Archive missing, downloading...") | ||
| do { | ||
| let downloadedFile = try Data( | ||
| contentsOf: URL( | ||
| string: "https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz")!) | ||
| try downloadedFile.write(to: URL(fileURLWithPath: archivePath)) | ||
| } catch { | ||
| fatalError("Could not download CIFAR dataset, error: \(error)") | ||
| } | ||
| } | ||
|
|
||
| print("Archive downloaded, processing...") | ||
|
|
||
| #if os(macOS) | ||
| let tarLocation = "/usr/bin/tar" | ||
| #else | ||
| let tarLocation = "/bin/tar" | ||
| #endif | ||
|
|
||
| if #available(macOS 10.13, *) { | ||
| let task = Process() | ||
| task.executableURL = URL(fileURLWithPath: tarLocation) | ||
| task.arguments = ["xzf", archivePath] | ||
| do { | ||
| try task.run() | ||
| task.waitUntilExit() | ||
| } catch { | ||
| print("CIFAR extraction failed with error: \(error)") | ||
| } | ||
| } else { | ||
| fatalError("Process() is missing from this platform") | ||
BradLarson marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| } | ||
|
|
||
| do { | ||
| try FileManager.default.removeItem(atPath: archivePath) | ||
| } catch { | ||
| fatalError("Could not remove archive, error: \(error)") | ||
BradLarson marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| } | ||
|
|
||
| print("Unarchiving completed") | ||
| } | ||
| } | ||
|
|
||
| extension Tensor where Scalar : _TensorFlowDataTypeCompatible { | ||
| extension Tensor where Scalar: _TensorFlowDataTypeCompatible { | ||
| public var _tfeTensorHandle: _AnyTensorHandle { | ||
| TFETensorHandle(_owning: handle._cTensorHandle) | ||
| } | ||
|
|
@@ -54,51 +93,64 @@ struct Example: TensorGroup { | |
| data = Tensor<Float>(handle: TensorHandle<Float>(handle: _handles[dataIndex])) | ||
| } | ||
|
|
||
| public var _tensorHandles: [_AnyTensorHandle] { [label._tfeTensorHandle, data._tfeTensorHandle] } | ||
| public var _tensorHandles: [_AnyTensorHandle] { | ||
|
||
| [label._tfeTensorHandle, data._tfeTensorHandle] | ||
| } | ||
| } | ||
|
|
||
| // Each CIFAR data file is provided as a Python pickle of NumPy arrays | ||
| func loadCIFARFile(named name: String, in directory: String = ".") -> Example { | ||
| downloadCIFAR10IfNotPresent(to: directory) | ||
| let np = Python.import("numpy") | ||
| let pickle = Python.import("pickle") | ||
| let path = "\(directory)/cifar-10-batches-py/\(name)" | ||
| let f = Python.open(path, "rb") | ||
| let res = pickle.load(f, encoding: "bytes") | ||
| let path = "\(directory)/cifar-10-batches-bin/\(name)" | ||
|
|
||
| let imageCount = 10000 | ||
| guard let fileContents = try? Data(contentsOf: URL(fileURLWithPath: path)) else { | ||
| fatalError("Could not read dataset file: \(name)") | ||
| } | ||
| guard fileContents.count == 30_730_000 else { | ||
| fatalError( | ||
| "Dataset file \(name) should have 30730000 bytes, instead had \(fileContents.count)") | ||
| } | ||
|
|
||
| let bytes = res[Python.bytes("data", encoding: "utf8")] | ||
| let labels = res[Python.bytes("labels", encoding: "utf8")] | ||
| var bytes: [UInt8] = [] | ||
| var labels: [Int64] = [] | ||
|
|
||
| let imageByteSize = 3073 | ||
| for imageIndex in 0..<imageCount { | ||
| let baseAddress = imageIndex * imageByteSize | ||
| labels.append(Int64(fileContents[baseAddress])) | ||
|
|
||
BradLarson marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| bytes.append(contentsOf: fileContents[(baseAddress + 1)..<(baseAddress + 3073)]) | ||
| } | ||
|
|
||
| let labelTensor = Tensor<Int64>(numpy: np.array(labels))! | ||
| let images = Tensor<UInt8>(numpy: bytes)! | ||
| let imageCount = images.shape[0] | ||
| let labelTensor = Tensor<Int64>(shape: [imageCount], scalars: labels) | ||
| let images = Tensor<UInt8>(shape: [imageCount, 3, 32, 32], scalars: bytes) | ||
|
|
||
| // reshape and transpose from the provided N(CHW) to TF default NHWC | ||
| let imageTensor = Tensor<Float>(images | ||
| .reshaped(to: [imageCount, 3, 32, 32]) | ||
| .transposed(withPermutations: [0, 2, 3, 1])) | ||
| // transpose from the provided N(CHW) to TF default NHWC | ||
BradLarson marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| let imageTensor = Tensor<Float>( | ||
BradLarson marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| images.transposed(withPermutations: [0, 2, 3, 1])) | ||
|
|
||
| let mean = Tensor<Float>([0.485, 0.456, 0.406]) | ||
| let std = Tensor<Float>([0.229, 0.224, 0.225]) | ||
| let std = Tensor<Float>([0.229, 0.224, 0.225]) | ||
| let imagesNormalized = ((imageTensor / 255.0) - mean) / std | ||
|
|
||
| return Example(label: Tensor<Int32>(labelTensor), data: imagesNormalized) | ||
| } | ||
|
|
||
| func loadCIFARTrainingFiles() -> Example { | ||
| let data = (1..<6).map { loadCIFARFile(named: "data_batch_\($0)") } | ||
| let data = (1..<6).map { loadCIFARFile(named: "data_batch_\($0).bin") } | ||
| return Example( | ||
| label: Raw.concat(concatDim: Tensor<Int32>(0), data.map { $0.label }), | ||
| data: Raw.concat(concatDim: Tensor<Int32>(0), data.map { $0.data }) | ||
| ) | ||
| } | ||
|
|
||
| func loadCIFARTestFile() -> Example { | ||
| return loadCIFARFile(named: "test_batch") | ||
| return loadCIFARFile(named: "test_batch.bin") | ||
| } | ||
|
|
||
| func loadCIFAR10() -> ( | ||
| training: Dataset<Example>, test: Dataset<Example>) { | ||
| training: Dataset<Example>, test: Dataset<Example> | ||
| ) { | ||
| let trainingDataset = Dataset<Example>(elements: loadCIFARTrainingFiles()) | ||
| let testDataset = Dataset<Example>(elements: loadCIFARTestFile()) | ||
| return (training: trainingDataset, test: testDataset) | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.