Skip to content
This repository was archived by the owner on Apr 23, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@
.DS_Store
.swiftpm
cifar-10-batches-py/
cifar-10-batches-bin/
118 changes: 85 additions & 33 deletions CIFAR/Data.swift
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright 2018 The TensorFlow Authors. All Rights Reserved.
// Copyright 2019 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand All @@ -12,24 +12,63 @@
// See the License for the specific language governing permissions and
// limitations under the License.

import Python
import Foundation
import FoundationNetworking
import TensorFlow

/// Use Python and shell calls to download and extract the CIFAR-10 tarball if not already done
/// This can fail for many reasons (e.g. lack of `wget`, `tar`, or an Internet connection)
func downloadCIFAR10IfNotPresent(to directory: String = ".") {
let subprocess = Python.import("subprocess")
let path = Python.import("os.path")
let filepath = "\(directory)/cifar-10-batches-py"
let isdir = Bool(path.isdir(filepath))!
if !isdir {
print("Downloading CIFAR data...")
let command = "wget -nv -O- https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz | tar xzf - -C \(directory)"
subprocess.call(command, shell: true)
let downloadPath = "\(directory)/cifar-10-batches-bin"
let directoryExists = FileManager.default.fileExists(atPath: downloadPath)

if !directoryExists {
print("Downloading CIFAR dataset...")
let archivePath = "\(directory)/cifar-10-binary.tar.gz"
let archiveExists = FileManager.default.fileExists(atPath: archivePath)
if !archiveExists {
print("Archive missing, downloading...")
do {
let downloadedFile = try Data(
contentsOf: URL(
string: "https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz")!)
try downloadedFile.write(to: URL(fileURLWithPath: archivePath))
} catch {
fatalError("Could not download CIFAR dataset, error: \(error)")
}
}

print("Archive downloaded, processing...")

#if os(macOS)
let tarLocation = "/usr/bin/tar"
#else
let tarLocation = "/bin/tar"
#endif

if #available(macOS 10.13, *) {
let task = Process()
task.executableURL = URL(fileURLWithPath: tarLocation)
task.arguments = ["xzf", archivePath]
do {
try task.run()
task.waitUntilExit()
} catch {
print("CIFAR extraction failed with error: \(error)")
}
} else {
fatalError("Process() is missing from this platform")
}

do {
try FileManager.default.removeItem(atPath: archivePath)
} catch {
fatalError("Could not remove archive, error: \(error)")
}

print("Unarchiving completed")
}
}

extension Tensor where Scalar : _TensorFlowDataTypeCompatible {
extension Tensor where Scalar: _TensorFlowDataTypeCompatible {
public var _tfeTensorHandle: _AnyTensorHandle {
TFETensorHandle(_owning: handle._cTensorHandle)
}
Expand All @@ -54,51 +93,64 @@ struct Example: TensorGroup {
data = Tensor<Float>(handle: TensorHandle<Float>(handle: _handles[dataIndex]))
}

public var _tensorHandles: [_AnyTensorHandle] { [label._tfeTensorHandle, data._tfeTensorHandle] }
public var _tensorHandles: [_AnyTensorHandle] {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You should be able to delete _tensorHandles and init(_handles:) now. They are automatically derived.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In line with that, in the Helpers.swift, has differentiableReduce been moved into swift-apis? I remember some discussion about that recently, forgot where things ended up. If so, I could pull all that out, too.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it's been moved to swift-apis. However I don't think the swift-apis in toolchains is up to date. Feel free to send a PR to apple/swift to update utils/update_checkout/update-checkout-config.json.

[label._tfeTensorHandle, data._tfeTensorHandle]
}
}

// Each CIFAR data file is provided as a Python pickle of NumPy arrays
func loadCIFARFile(named name: String, in directory: String = ".") -> Example {
downloadCIFAR10IfNotPresent(to: directory)
let np = Python.import("numpy")
let pickle = Python.import("pickle")
let path = "\(directory)/cifar-10-batches-py/\(name)"
let f = Python.open(path, "rb")
let res = pickle.load(f, encoding: "bytes")
let path = "\(directory)/cifar-10-batches-bin/\(name)"

let imageCount = 10000
guard let fileContents = try? Data(contentsOf: URL(fileURLWithPath: path)) else {
fatalError("Could not read dataset file: \(name)")
}
guard fileContents.count == 30_730_000 else {
fatalError(
"Dataset file \(name) should have 30730000 bytes, instead had \(fileContents.count)")
}

let bytes = res[Python.bytes("data", encoding: "utf8")]
let labels = res[Python.bytes("labels", encoding: "utf8")]
var bytes: [UInt8] = []
var labels: [Int64] = []

let imageByteSize = 3073
for imageIndex in 0..<imageCount {
let baseAddress = imageIndex * imageByteSize
labels.append(Int64(fileContents[baseAddress]))

bytes.append(contentsOf: fileContents[(baseAddress + 1)..<(baseAddress + 3073)])
}

let labelTensor = Tensor<Int64>(numpy: np.array(labels))!
let images = Tensor<UInt8>(numpy: bytes)!
let imageCount = images.shape[0]
let labelTensor = Tensor<Int64>(shape: [imageCount], scalars: labels)
let images = Tensor<UInt8>(shape: [imageCount, 3, 32, 32], scalars: bytes)

// reshape and transpose from the provided N(CHW) to TF default NHWC
let imageTensor = Tensor<Float>(images
.reshaped(to: [imageCount, 3, 32, 32])
.transposed(withPermutations: [0, 2, 3, 1]))
// transpose from the provided N(CHW) to TF default NHWC
let imageTensor = Tensor<Float>(
images.transposed(withPermutations: [0, 2, 3, 1]))

let mean = Tensor<Float>([0.485, 0.456, 0.406])
let std = Tensor<Float>([0.229, 0.224, 0.225])
let std = Tensor<Float>([0.229, 0.224, 0.225])
let imagesNormalized = ((imageTensor / 255.0) - mean) / std

return Example(label: Tensor<Int32>(labelTensor), data: imagesNormalized)
}

func loadCIFARTrainingFiles() -> Example {
let data = (1..<6).map { loadCIFARFile(named: "data_batch_\($0)") }
let data = (1..<6).map { loadCIFARFile(named: "data_batch_\($0).bin") }
return Example(
label: Raw.concat(concatDim: Tensor<Int32>(0), data.map { $0.label }),
data: Raw.concat(concatDim: Tensor<Int32>(0), data.map { $0.data })
)
}

func loadCIFARTestFile() -> Example {
return loadCIFARFile(named: "test_batch")
return loadCIFARFile(named: "test_batch.bin")
}

func loadCIFAR10() -> (
training: Dataset<Example>, test: Dataset<Example>) {
training: Dataset<Example>, test: Dataset<Example>
) {
let trainingDataset = Dataset<Example>(elements: loadCIFARTrainingFiles())
let testDataset = Dataset<Example>(elements: loadCIFARTestFile())
return (training: trainingDataset, test: testDataset)
Expand Down
7 changes: 1 addition & 6 deletions CIFAR/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,7 @@ classification on the [CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html) da
## Setup

You'll need [the latest version][INSTALL] of Swift for TensorFlow
installed and added to your path. Additionally, the data loader requires Python
3.x (rather than Python 2.7), `wget`, and `numpy`.

> Note: For macOS, you need to set up the `PYTHON_LIBRARY` to help the Swift for
> TensorFlow find the `libpython3.<minor-version>.dylib` file, e.g., in
> `homebrew`.
installed and added to your path.

To train the default model, run:

Expand Down
2 changes: 0 additions & 2 deletions CIFAR/main.swift
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@
// limitations under the License.

import TensorFlow
import Python
PythonLibrary.useVersion(3)

let batchSize = 100

Expand Down
120 changes: 85 additions & 35 deletions ResNet/Data.swift
Original file line number Diff line number Diff line change
Expand Up @@ -12,27 +12,63 @@
// See the License for the specific language governing permissions and
// limitations under the License.

import Python
import Foundation
import FoundationNetworking
import TensorFlow

/// Use Python and shell calls to download and extract the CIFAR-10 tarball if not already done
/// This can fail for many reasons (e.g. lack of `wget`, `tar`, or an Internet connection)
func downloadCIFAR10IfNotPresent(to directory: String = ".") {
let subprocess = Python.import("subprocess")
let path = Python.import("os.path")
let filepath = "\(directory)/cifar-10-batches-py"
let isdir = Bool(path.isdir(filepath))!
if !isdir {
print("Downloading CIFAR data...")
let command = """
wget -nv -O- https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz | tar xzf - \
-C \(directory)
"""
subprocess.call(command, shell: true)
let downloadPath = "\(directory)/cifar-10-batches-bin"
let directoryExists = FileManager.default.fileExists(atPath: downloadPath)

if !directoryExists {
print("Downloading CIFAR dataset...")
let archivePath = "\(directory)/cifar-10-binary.tar.gz"
let archiveExists = FileManager.default.fileExists(atPath: archivePath)
if !archiveExists {
print("Archive missing, downloading...")
do {
let downloadedFile = try Data(
contentsOf: URL(
string: "https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz")!)
try downloadedFile.write(to: URL(fileURLWithPath: archivePath))
} catch {
fatalError("Could not download CIFAR dataset, error: \(error)")
}
}

print("Archive downloaded, processing...")

#if os(macOS)
let tarLocation = "/usr/bin/tar"
#else
let tarLocation = "/bin/tar"
#endif

if #available(macOS 10.13, *) {
let task = Process()
task.executableURL = URL(fileURLWithPath: tarLocation)
task.arguments = ["xzf", archivePath]
do {
try task.run()
task.waitUntilExit()
} catch {
print("CIFAR extraction failed with error: \(error)")
}
} else {
fatalError("Process() is missing from this platform")
}

do {
try FileManager.default.removeItem(atPath: archivePath)
} catch {
fatalError("Could not remove archive, error: \(error)")
}

print("Unarchiving completed")
}
}

extension Tensor where Scalar : _TensorFlowDataTypeCompatible {
extension Tensor where Scalar: _TensorFlowDataTypeCompatible {
public var _tfeTensorHandle: _AnyTensorHandle {
TFETensorHandle(_owning: handle._cTensorHandle)
}
Expand All @@ -57,50 +93,64 @@ struct Example: TensorGroup {
data = Tensor<Float>(handle: TensorHandle<Float>(handle: _handles[dataIndex]))
}

public var _tensorHandles: [_AnyTensorHandle] { [label._tfeTensorHandle, data._tfeTensorHandle] }
public var _tensorHandles: [_AnyTensorHandle] {
[label._tfeTensorHandle, data._tfeTensorHandle]
}
}

// Each CIFAR data file is provided as a Python pickle of NumPy arrays
func loadCIFARFile(named name: String, in directory: String = ".") -> Example {
downloadCIFAR10IfNotPresent(to: directory)
let np = Python.import("numpy")
let pickle = Python.import("pickle")
let path = "\(directory)/cifar-10-batches-py/\(name)"
let f = Python.open(path, "rb")
let res = pickle.load(f, encoding: "bytes")
let path = "\(directory)/cifar-10-batches-bin/\(name)"

let imageCount = 10000
guard let fileContents = try? Data(contentsOf: URL(fileURLWithPath: path)) else {
fatalError("Could not read dataset file: \(name)")
}
guard fileContents.count == 30_730_000 else {
fatalError(
"Dataset file \(name) should have 30730000 bytes, instead had \(fileContents.count)")
}

let bytes = res[Python.bytes("data", encoding: "utf8")]
let labels = res[Python.bytes("labels", encoding: "utf8")]
var bytes: [UInt8] = []
var labels: [Int64] = []

let imageByteSize = 3073
for imageIndex in 0..<imageCount {
let baseAddress = imageIndex * imageByteSize
labels.append(Int64(fileContents[baseAddress]))

bytes.append(contentsOf: fileContents[(baseAddress + 1)..<(baseAddress + 3073)])
}

let labelTensor = Tensor<Int64>(numpy: np.array(labels))!
let images = Tensor<UInt8>(numpy: bytes)!
let imageCount = images.shape[0]
let labelTensor = Tensor<Int64>(shape: [imageCount], scalars: labels)
let images = Tensor<UInt8>(shape: [imageCount, 3, 32, 32], scalars: bytes)

// reshape and transpose from the provided N(CHW) to TF default NHWC
let imageTensor = Tensor<Float>(images
.reshaped(to: [imageCount, 3, 32, 32])
.transposed(withPermutations: [0, 2, 3, 1]))
// transpose from the provided N(CHW) to TF default NHWC
let imageTensor = Tensor<Float>(
images.transposed(withPermutations: [0, 2, 3, 1]))

let mean = Tensor<Float>([0.485, 0.456, 0.406])
let std = Tensor<Float>([0.229, 0.224, 0.225])
let std = Tensor<Float>([0.229, 0.224, 0.225])
let imagesNormalized = ((imageTensor / 255.0) - mean) / std

return Example(label: Tensor<Int32>(labelTensor), data: imagesNormalized)
}

func loadCIFARTrainingFiles() -> Example {
let data = (1..<6).map { loadCIFARFile(named: "data_batch_\($0)") }
let data = (1..<6).map { loadCIFARFile(named: "data_batch_\($0).bin") }
return Example(
label: Raw.concat(concatDim: Tensor<Int32>(0), data.map { $0.label }),
data: Raw.concat(concatDim: Tensor<Int32>(0), data.map { $0.data })
)
}

func loadCIFARTestFile() -> Example {
return loadCIFARFile(named: "test_batch")
return loadCIFARFile(named: "test_batch.bin")
}

func loadCIFAR10() -> (training: Dataset<Example>, test: Dataset<Example>) {
func loadCIFAR10() -> (
training: Dataset<Example>, test: Dataset<Example>
) {
let trainingDataset = Dataset<Example>(elements: loadCIFARTrainingFiles())
let testDataset = Dataset<Example>(elements: loadCIFARTestFile())
return (training: trainingDataset, test: testDataset)
Expand Down
9 changes: 2 additions & 7 deletions ResNet/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,13 @@ dataset.
## Setup

You'll need [the latest version][INSTALL] of Swift for TensorFlow
installed and added to your path. Additionally, the data loader requires Python
3.x (rather than Python 2.7), `wget`, and `numpy`.

> Note: For macOS, you need to set up the `PYTHON_LIBRARY` to help the Swift for
> TensorFlow find the `libpython3.<minor-version>.dylib` file, e.g., in
> `homebrew`.
installed and added to your path.

To train the model on CIFAR-10, run:

```
cd swift-models
swift run ResNet
swift run -c release ResNet
```

[INSTALL]: https://github.com/tensorflow/swift/blob/master/Installation.md
2 changes: 0 additions & 2 deletions ResNet/main.swift
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@
// limitations under the License.

import TensorFlow
import Python
PythonLibrary.useVersion(3)

let batchSize = 100

Expand Down