Skip to content
This repository was archived by the owner on Apr 23, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions Examples/LeNet-MNIST/main.swift
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,14 @@ var classifier = Sequential {

var optimizer = SGD(for: classifier, learningRate: 0.1)

let trainingProgress = TrainingProgress()
var trainingLoop = TrainingLoop(
training: dataset.training,
validation: dataset.validation,
optimizer: optimizer,
lossFunction: softmaxCrossEntropy,
callbacks: [trainingProgress.update])
metrics: [.accuracy],
callbacks: [CSVLogger(liveStatistics: false).log])
trainingLoop.statisticsRecorder.liveStatistics = false
trainingLoop.progressPrinter.liveStatistics = false

try! trainingLoop.fit(&classifier, epochs: epochCount, on: device)
3 changes: 1 addition & 2 deletions Examples/MobileNetV1-Imagenette/main.swift
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,11 @@ let dataset = Imagenette(batchSize: 64, inputSize: .resized320, outputSize: 224,
var model = MobileNetV1(classCount: 10)
let optimizer = SGD(for: model, learningRate: 0.02, momentum: 0.9)

let trainingProgress = TrainingProgress()
var trainingLoop = TrainingLoop(
training: dataset.training,
validation: dataset.validation,
optimizer: optimizer,
lossFunction: softmaxCrossEntropy,
callbacks: [trainingProgress.update])
metrics: [.accuracy])

try! trainingLoop.fit(&model, epochs: 10, on: device)
3 changes: 1 addition & 2 deletions Examples/MobileNetV2-Imagenette/main.swift
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,11 @@ let dataset = Imagenette(batchSize: 64, inputSize: .resized320, outputSize: 224,
var model = MobileNetV2(classCount: 10)
let optimizer = SGD(for: model, learningRate: 0.002, momentum: 0.9)

let trainingProgress = TrainingProgress()
var trainingLoop = TrainingLoop(
training: dataset.training,
validation: dataset.validation,
optimizer: optimizer,
lossFunction: softmaxCrossEntropy,
callbacks: [trainingProgress.update])
metrics: [.accuracy])

try! trainingLoop.fit(&model, epochs: 10, on: device)
3 changes: 1 addition & 2 deletions Examples/ResNet-CIFAR10/main.swift
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,11 @@ let dataset = CIFAR10(batchSize: 10, on: device)
var model = ResNet(classCount: 10, depth: .resNet56, downsamplingInFirstStage: false)
var optimizer = SGD(for: model, learningRate: 0.001)

let trainingProgress = TrainingProgress()
var trainingLoop = TrainingLoop(
training: dataset.training,
validation: dataset.validation,
optimizer: optimizer,
lossFunction: softmaxCrossEntropy,
callbacks: [trainingProgress.update])
metrics: [.accuracy])

try! trainingLoop.fit(&model, epochs: 10, on: device)
4 changes: 2 additions & 2 deletions Examples/VGG-Imagewoof/main.swift
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,12 @@ public func scheduleLearningRate<L: TrainingLoopProtocol>(
}
}

let trainingProgress = TrainingProgress()
var trainingLoop = TrainingLoop(
training: dataset.training,
validation: dataset.validation,
optimizer: optimizer,
lossFunction: softmaxCrossEntropy,
callbacks: [trainingProgress.update, scheduleLearningRate])
metrics: [.accuracy],
callbacks: [scheduleLearningRate])

try! trainingLoop.fit(&model, epochs: 90, on: device)
1 change: 1 addition & 0 deletions Support/FileSystem.swift
Original file line number Diff line number Diff line change
Expand Up @@ -39,4 +39,5 @@ public protocol File {
func read(position: Int, count: Int) throws -> Data
func write(_ value: Data) throws
func write(_ value: Data, position: Int) throws
func append(_ value: Data) throws
}
7 changes: 7 additions & 0 deletions Support/FoundationFileSystem.swift
Original file line number Diff line number Diff line change
Expand Up @@ -58,4 +58,11 @@ public struct FoundationFile: File {
// TODO: Incorporate file offset.
try value.write(to: location)
}

/// Appends `value` to the end of the file at `location`.
///
/// - Parameter value: The bytes to append.
/// - Throws: An error if the file cannot be opened for updating, or if
///   seeking or writing fails.
public func append(_ value: Data) throws {
  let fileHandler = try FileHandle(forUpdating: location)
  // Close the handle on every exit path; the original leaked the handle
  // if seekToEnd() or write(contentsOf:) threw. Close errors are
  // intentionally ignored since the write itself already succeeded or
  // its error is being propagated.
  defer { try? fileHandler.close() }
  try fileHandler.seekToEnd()
  try fileHandler.write(contentsOf: value)
}
}
6 changes: 4 additions & 2 deletions TrainingLoop/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
add_library(TrainingLoop
LossFunctions.swift
Metrics.swift
TrainingLoop.swift
TrainingProgress.swift
TrainingStatistics.swift)
Callbacks/StatisticsRecorder.swift
Callbacks/ProgressPrinter.swift
Callbacks/CSVLogger.swift)
target_link_libraries(TrainingLoop PUBLIC
ModelSupport)
set_target_properties(TrainingLoop PROPERTIES
Expand Down
79 changes: 79 additions & 0 deletions TrainingLoop/Callbacks/CSVLogger.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import Foundation
import ModelSupport

/// A callback-based handler for logging the statistics to CSV file.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice doc comment. English nit:

Suggested change
/// A callback-based handler for logging the statistics to CSV file.
/// A callback-based handler for logging statistics to a CSV file.

I would consider not leading with “callback-based.” In fact, "logging" and "CSV File" are kind of implied by the name. So the best description would explain what's being logged. “Statistics” is good, but what kind of statistics? Training statistics, maybe? Maybe this should be thought of as “a log file” rather than “a logger”?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated the doc comment. Any reason why this is LogFile? (Logger makes sense to me since it's a logger, not a file.)

Copy link
Contributor

@dabrahams dabrahams Sep 23, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This thing can only be constructed to log to a path in the filesystem, and when we invoke its one public method, log, it appends to that file. Properly used, if you have multiple instances, each instance logs to a different file. So instances have a 1-1 correspondence to log files. The word Logger doesn't imply any of those things. A more general Logger might be an interesting abstraction, but it would have a different API.

If I see this code, I know exactly what's happening.

// No argument label needed, arguably, because when you construct a thing called “File” with
// a string, the string is obviously a path.
LogFile eventLog(s)

eventLog.append(blah, blah, blah)

With Logger and log, it's less clear.

public class CSVLogger {
public var path: String
public var liveStatistics: Bool

let foundationFS: FoundationFileSystem
let foundationFile: FoundationFile

/// Create an instance that log statistics during the training loop.
///
/// - Parameters:
/// - liveStatistics: whether or not log the statistics lively on each batch.
public init(withPath path: String = "run/log.csv", liveStatistics: Bool = true) {
self.path = path
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems very unlikely to me that we actually need to store path in addition to foundationFile. Consider whether it can/should be dropped because you can get it from foundationFile.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same. Removed foundationFile and kept the path.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, I'm not sure you want to open and close the file for every line logged, though. (I am presuming that the foundationFile object keeps the file open)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How did you “resolve” this comment?

self.liveStatistics = liveStatistics
self.foundationFS = FoundationFileSystem()
self.foundationFile = FoundationFile(path: path)
}

/// The callback used to hook into the TrainingLoop for logging statistics.
///
/// - Parameters:
/// - loop: The TrainingLoop where an event has occurred.
/// - event: The training or validation event that this callback is responding to.
/// The callback used to hook into the TrainingLoop for logging statistics.
///
/// On `batchEnd`, appends one CSV data row — every batch when
/// `liveStatistics` is true, otherwise only for the final batch. On first
/// use it creates the parent directory (if the path has one) and writes the
/// header row. All other events are ignored.
///
/// - Parameters:
///   - loop: The TrainingLoop where an event has occurred.
///   - event: The training or validation event that this callback is responding to.
public func log<L: TrainingLoopProtocol>(_ loop: inout L, event: TrainingLoopEvent) throws {
  switch event {
  case .batchEnd:
    guard let epochIndex = loop.epochIndex, let epochCount = loop.epochCount,
      let batchIndex = loop.batchIndex, let batchCount = loop.batchCount
    else {
      break
    }

    // When not logging live, only the last batch of the phase is recorded.
    if !liveStatistics && (batchIndex + 1 != batchCount) {
      break
    }

    guard let stats = loop.lastStatsLog else {
      break
    }

    if !FileManager.default.fileExists(atPath: path) {
      // Only create a parent directory when the path actually contains a
      // "/"; the original force-unwrapped lastIndex(of: "/") and crashed
      // on a bare filename such as "log.csv".
      if let separatorIndex = path.lastIndex(of: "/") {
        try foundationFS.createDirectoryIfMissing(at: String(path[..<separatorIndex]))
      }
      try writeHeader(stats: stats)
    }
    try writeDataRow(
      epoch: "\(epochIndex + 1)/\(epochCount)",
      batch: "\(batchIndex + 1)/\(batchCount)",
      stats: stats)
  default:
    break
  }
}

/// Writes the CSV header line ("epoch, batch, <metric names...>") to `path`,
/// replacing any existing contents of the file.
///
/// - Parameter stats: the (name, value) pairs whose names become columns.
/// - Throws: rethrows any error raised while writing the file.
func writeHeader(stats: [(String, Float)]) throws {
  let columns = ["epoch", "batch"] + stats.map { $0.0 }
  do {
    try columns.joined(separator: ", ").write(toFile: path, atomically: true, encoding: .utf8)
  } catch {
    print("Unexpected error in writing header line: \(error).")
    throw error
  }
}

func writeDataRow(epoch: String, batch: String, stats: [(String, Float)]) throws {
let dataRow: Data = (
"\n" + ([epoch, batch] + stats.map { String($0.1) }).joined(separator: ", ")
).data(using: .utf8)!
do {
try foundationFile.append(dataRow)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't it important to do this atomically if there are multiple writers to the same log file?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does FileHandle support atomic writes? (Also, I don't think we expect multiple files to be written at the same time.)

} catch {
print("Unexpected error in writing data row: \(error).")
throw error
}
}
}
100 changes: 100 additions & 0 deletions TrainingLoop/Callbacks/ProgressPrinter.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
// Copyright 2020 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import Foundation

let progressBarLength = 30
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

length in what unit? I suppose it's probably characters. Being a top-level declaration this should have a doc comment, and that would be a perfect place to put the answer. Is this the number of = signs, or the whole length printed, or…?


/// A callback-based handler for printing the training or validation progress.
public class ProgressPrinter {
  /// Whether statistics are printed after every batch (`true`) or only after
  /// the final batch of a phase (`false`).
  public var liveStatistics: Bool

  /// Create an instance that prints progress during the training loop.
  /// The progress contains a dynamic progress bar followed by statistics of metrics.
  ///
  /// - Parameters:
  ///   - liveStatistics: whether or not show the statistics lively on each batch.
  public init(liveStatistics: Bool = true) {
    self.liveStatistics = liveStatistics
  }

  /// The callback used to hook into the TrainingLoop for printing progress.
  ///
  /// An example of the progress would be:
  /// Epoch 1/12
  /// 468/468 [==============================] - loss: 0.4819 - accuracy: 0.8513
  /// 79/79 [==============================] - loss: 0.1520 - accuracy: 0.9521
  ///
  /// - Parameters:
  ///   - loop: The TrainingLoop where an event has occurred.
  ///   - event: The training or validation event that this callback is responding to.
  public func print<L: TrainingLoopProtocol>(_ loop: inout L, event: TrainingLoopEvent) throws {
    switch event {
    case .epochStart:
      guard let epochIndex = loop.epochIndex, let epochCount = loop.epochCount else {
        return
      }

      Swift.print("Epoch \(epochIndex + 1)/\(epochCount)")
    case .batchEnd:
      guard let batchIndex = loop.batchIndex, let batchCount = loop.batchCount else {
        return
      }

      let progressBar = formatProgressBar(
        progress: Float(batchIndex + 1) / Float(batchCount), length: progressBarLength)
      var stats: String = ""
      if liveStatistics || (batchCount == (batchIndex + 1)) {
        stats = formatStats(loop.lastStatsLog)
      }

      // "\r" rewinds to the start of the line so the bar redraws in place;
      // fflush is needed because the line has no trailing newline.
      Swift.print(
        "\r\(batchIndex + 1)/\(batchCount) \(progressBar)\(stats)",
        terminator: ""
      )
      fflush(stdout)
    case .epochEnd:
      Swift.print("")
    case .validationStart:
      Swift.print("")
    default:
      return
    }
  }

  /// Renders a textual progress bar such as "[=========>....................]".
  ///
  /// - Parameters:
  ///   - progress: completion fraction, expected in [0, 1].
  ///   - length: interior width of the bar in characters.
  func formatProgressBar(progress: Float, length: Int) -> String {
    let progressSteps = Int(round(Float(length) * progress))
    let leading = String(repeating: "=", count: progressSteps)
    let separator: String
    let trailing: String
    // Fixed: compare and pad using the `length` parameter. The original
    // ignored it and used the global `progressBarLength`, which would
    // miscount (or crash on a negative repeat count) for any other length.
    if progressSteps < length {
      separator = ">"
      trailing = String(repeating: ".", count: length - progressSteps - 1)
    } else {
      separator = ""
      trailing = ""
    }
    return "[\(leading)\(separator)\(trailing)]"
  }

  /// Formats statistics as " - name: value" segments with four decimal
  /// places; returns an empty string for `nil`.
  func formatStats(_ stats: [(String, Float)]?) -> String {
    var result = ""
    if let stats = stats {
      for stat in stats {
        result += " - \(stat.0): \(String(format: "%.4f", stat.1))"
      }
    }
    return result
  }
}
87 changes: 87 additions & 0 deletions TrainingLoop/Callbacks/StatisticsRecorder.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
// Copyright 2020 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import TensorFlow

/// A callback-based handler for recording statistics.
///
/// Data produced by this handler can be used by ProgressPrinter, CSVLogger, etc.
public class StatisticsRecorder {
public var liveStatistics: Bool

var metricMeasurers: [MetricsMeasurer]

/// Create an instance that records 'metrics' during the training loop.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
/// Create an instance that records 'metrics' during the training loop.
/// Creates an instance that records `metrics` during the training loop.
  • Say what it does, rather than issuing an imperative command.
  • Backticks for code voice.

///
/// Recording happens every batch by default or
/// only when last batch ends when 'liveStatistics' is set to false.
///
/// - Parameters:
/// - liveStatistics: whether or not record lively on each batch.
/// This has an impact on performance, due to materialization of tensors, updating values
/// on every batch can reduce training speed by up to 30%.
/// - metrics: an array of TrainingMetrics to record.
/// Creates an instance that records `metrics` during the training loop.
///
/// - Parameters:
///   - liveStatistics: whether statistics are recorded on every batch
///     (`true`, the default) or only on the last batch (`false`).
///   - metrics: the `TrainingMetrics` to record.
public init(liveStatistics: Bool = true, metrics: [TrainingMetrics]) {
  self.liveStatistics = liveStatistics
  self.metricMeasurers = metrics.map(\.measurer)
}

/// The callback used to hook into the TrainingLoop for recording statistics.
///
/// It will record the statistics into lastStatsLog in the loop where other
/// callbacks can consume from.
///
/// - Parameters:
/// - loop: The TrainingLoop where an event has occurred.
/// - event: The training or validation event that this callback is responding to.
public func record<L: TrainingLoopProtocol>(_ loop: inout L, event: TrainingLoopEvent) throws {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Again, this looks like it should be a mutating method on TrainingLoopProtocol

switch event {
case .trainingStart, .validationStart:
resetMetricMeasurers()
case .batchEnd:
if let loss = loop.lastStepLoss, let output = loop.lastStepOutput,
let target = loop.lastStepTarget
{
accumulateMetrics(loss: loss, predictions: output, labels: target)
}

if let batchIndex = loop.batchIndex, let batchCount = loop.batchCount {
if liveStatistics || (batchCount == (batchIndex + 1)) {
loop.lastStatsLog = computeMetrics()
}
}
default:
return
}
}

/// Resets every measurer so a new phase (training or validation) starts
/// accumulating from scratch.
func resetMetricMeasurers() {
  // Mutate through the array subscript so the stored measurers are updated
  // in place.
  metricMeasurers.indices.forEach { metricMeasurers[$0].reset() }
}

/// Feeds one batch's loss, predictions, and labels into every measurer.
func accumulateMetrics<Output, Target>(loss: Tensor<Float>, predictions: Output, labels: Target) {
  // Mutate through the array subscript so the stored measurers are updated
  // in place.
  metricMeasurers.indices.forEach {
    metricMeasurers[$0].accumulate(loss: loss, predictions: predictions, labels: labels)
  }
}

/// Returns the current value of every metric as (name, value) pairs, in the
/// order the measurers were configured.
func computeMetrics() -> [(String, Float)] {
  metricMeasurers.map { ($0.name, $0.measure()) }
}
}
Loading