This repository was archived by the owner on Apr 23, 2025. It is now read-only.
Merged
Changes from 3 commits
8 changes: 3 additions & 5 deletions Support/FoundationFileSystem.swift
@@ -59,13 +59,11 @@ public struct FoundationFile: File {
try value.write(to: location)
}

/// Append data to the file.
///
/// Parameter value: data to be appended at the end.
public func append(_ value: Data) throws {
/// Appends the bytes in `suffix` to the file.
public func append(_ suffix: Data) throws {
let fileHandler = try FileHandle(forUpdating: location)
try fileHandler.seekToEnd()
try fileHandler.write(contentsOf: value)
try fileHandler.write(contentsOf: suffix)
try fileHandler.close()
}
}
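
For illustration, a minimal usage sketch of the renamed append(_:), using the FoundationFile(path:) initializer that CSVLogger below relies on; the path is hypothetical, and the file must already exist since the handle is opened for updating.

import Foundation
// plus the module that defines FoundationFile

// Hypothetical usage: append one CSV row to an existing log file.
func appendRow() throws {
  let row = "1, 42, 0.4819, 0.8513\n"
  try FoundationFile(path: "run/log.csv").append(row.data(using: .utf8)!)
}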
14 changes: 10 additions & 4 deletions TrainingLoop/Callbacks/CSVLogger.swift
@@ -7,13 +7,13 @@ public enum CSVLoggerError: Error {

/// A handler for logging training and validation statistics to a CSV file.
Contributor

I still don't know what a handler is. Still don't think this is a strong abstraction.

Contributor Author

It's an observer callback that TrainingLoop is designed around. https://docs.google.com/document/d/1CtVFhV8OcQ4E7CmNyfFeZu0IgnUPx86tfQGXiHJUXz0/edit?ts=5ebef977#heading=h.b2so9ayrnyyp

You proposed making it a function in TrainingLoop. Here are the points that make me favor wrapping callbacks in separate classes (quick sketch below):

  1. Decouple from TrainingLoop
  2. Use stored properties to share callback settings
  3. Follow this pattern for all callbacks

Let's discuss more offline or in the seminar meeting!?
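
Roughly the shape I have in mind, as a quick illustrative sketch (the class name and its setting are hypothetical, not part of this PR):

// A callback wrapped in its own class: decoupled from TrainingLoop,
// with its settings kept as stored properties.
public class EveryNBatchesPrinter {
  /// Stored setting shared across invocations of the callback.
  public var interval: Int

  public init(interval: Int = 10) {
    self.interval = interval
  }

  /// Matches the TrainingLoopCallback signature, so it can be added to `loop.callbacks`.
  public func log<L: TrainingLoopProtocol>(_ loop: inout L, event: TrainingLoopEvent) throws {
    guard event == .batchEnd,
      let batchIndex = loop.batchIndex,
      let stats = loop.lastStatsLog
    else { return }
    if (batchIndex + 1) % interval == 0 {
      print("batch \(batchIndex + 1): \(stats)")
    }
  }
}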

public class CSVLogger {
/// The path of the file that statistics are logged to.
/// The path of the file to which statistics are logged.
public var path: String

// True iff the header of the CSV file has been written.
fileprivate var headerWritten: Bool

/// Creates an instance that logs to a file with the given path.
/// Creates an instance that logs to a file with the given `path`.
///
/// Throws: File system errors.
public init(path: String = "run/log.csv") throws {
@@ -32,7 +32,7 @@ public class CSVLogger {
self.headerWritten = false
}

/// Logs the statistics for the 'loop' when 'batchEnd' event happens;
/// Logs the statistics for `loop` when a `batchEnd` event happens;
/// ignoring other events.
///
/// Throws: File system errors.
@@ -43,7 +43,6 @@
let batchIndex = loop.batchIndex, let batchCount = loop.batchCount,
let stats = loop.lastStatsLog
else {
// No-Op if trainingLoop doesn't set the required values for stats logging.
return
}

@@ -61,11 +60,18 @@
}
}

/// Writes a row of column names to the file.
///
/// Column names are "epoch", "batch" and the `name` of each element of `stats`,
/// in that order.
func writeHeader(stats: [(name: String, value: Float)]) throws {
let header = (["epoch", "batch"] + stats.lazy.map { $0.name }).joined(separator: ", ") + "\n"
try FoundationFile(path: path).append(header.data(using: .utf8)!)
}

/// Appends a row of statistics to the file: `epoch` under the "epoch" column,
/// `batch` under the "batch" column, and each `value` in `stats` under the
/// column named by its `name`.
func writeDataRow(epoch: String, batch: String, stats: [(name: String, value: Float)]) throws {
let dataRow = ([epoch, batch] + stats.lazy.map { String($0.value) }).joined(separator: ", ")
+ "\n"
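For illustration, what the two helpers above emit for a hypothetical stats array (values are made up; the ", " separator comes from the code shown above):

// Illustrative only.
let stats: [(name: String, value: Float)] = [(name: "loss", value: 0.4819), (name: "accuracy", value: 0.8513)]
// writeHeader(stats: stats) appends the line:               epoch, batch, loss, accuracy
// writeDataRow(epoch: "1", batch: "42", stats: stats) appends:  1, 42, 0.4819, 0.8513
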
31 changes: 20 additions & 11 deletions TrainingLoop/Callbacks/ProgressPrinter.swift
@@ -14,28 +14,37 @@

import Foundation

let progressBarLength = 30

/// A handler for printing the training and validation progress.
///
/// The progress includes the epoch and batch index the training is currently
/// in, the percentage of the full training/validation set processed so far,
/// and metric statistics.
public class ProgressPrinter {
/// Print training or validation progress in response of the 'event'.
/// Length of the complete progress bar measured in count of `=` signs.
public var progressBarLength: Int

/// Creates an instance that prints training progress with a complete
/// progress bar `progressBarLength` characters long.
public init(progressBarLength: Int = 30) {
self.progressBarLength = progressBarLength
}

/// Prints training or validation progress in response to the `event`.
///
/// An example of the progress would be:
/// Epoch 1/12
/// 468/468 [==============================] - loss: 0.4819 - accuracy: 0.8513
/// 79/79 [==============================] - loss: 0.1520 - accuracy: 0.9521
public func print<L: TrainingLoopProtocol>(_ loop: inout L, event: TrainingLoopEvent) throws {
/// 58/79 [======================>.......] - loss: 0.1520 - accuracy: 0.9521
public func printProgress<L: TrainingLoopProtocol>(_ loop: inout L, event: TrainingLoopEvent) throws {
switch event {
case .epochStart:
guard let epochIndex = loop.epochIndex, let epochCount = loop.epochCount else {
// No-Op if trainingLoop doesn't set the required values for progress printing.
return
}

Swift.print("Epoch \(epochIndex + 1)/\(epochCount)")
print("Epoch \(epochIndex + 1)/\(epochCount)")
case .batchEnd:
guard let batchIndex = loop.batchIndex, let batchCount = loop.batchCount else {
// No-Op if trainingLoop doesn't set the required values for progress printing.
return
}

@@ -46,15 +55,15 @@ public class ProgressPrinter {
stats = formatStats(lastStatsLog)
}

Swift.print(
print(
"\r\(batchIndex + 1)/\(batchCount) \(progressBar)\(stats)",
terminator: ""
)
fflush(stdout)
case .epochEnd:
Swift.print("")
print("")
case .validationStart:
Swift.print("")
print("")
default:
return
}
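A hedged sketch of using the new initializer: build a callback list with a wider bar, mirroring how TrainingLoop wires progressPrinter.printProgress below; the helper name is hypothetical.

// Hypothetical helper: a 50-character progress bar instead of the default 30.
func makeCallbacks<L: TrainingLoopProtocol>() -> [TrainingLoopCallback<L>] {
  let printer = ProgressPrinter(progressBarLength: 50)
  return [printer.printProgress]
}
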
23 changes: 15 additions & 8 deletions TrainingLoop/Callbacks/StatisticsRecorder.swift
@@ -17,31 +17,35 @@ import TensorFlow
///
/// Data produced by this handler can be used by ProgressPrinter, CVSLogger, etc.
public class StatisticsRecorder {
/// A Closure that returns if should call 'reset' on metricMeasurers.
/// A function that returns `true` iff the recorder should call `reset`
/// on `metricMeasurers`.
public var shouldReset:
(
_ batchIndex: Int, _ batchCount: Int, _ epochIndex: Int, _ epochCount: Int,
_ event: TrainingLoopEvent
) -> Bool

/// A Closure that returns if should call 'accumulate' on metricMeasurers.
/// A function that returns `true` iff the recorder should call `accumulate`
/// on `metricMeasurers`.
public var shouldAccumulate:
(
_ batchIndex: Int, _ batchCount: Int, _ epochIndex: Int, _ epochCount: Int,
_ event: TrainingLoopEvent
) -> Bool

/// A Closure that returns if should call 'compute' on metricMeasurers.
/// A function that returns `true` iff the recorder should call `measure`
/// on `metricMeasurers`.
public var shouldCompute:
(
_ batchIndex: Int, _ batchCount: Int, _ epochIndex: Int, _ epochCount: Int,
_ event: TrainingLoopEvent
) -> Bool

/// Instances of MetricsMeasurers.
/// Instances of MetricsMeasurers on which the recorder periodically resets,
/// accumulates, and computes statistics.
fileprivate var metricMeasurers: [MetricsMeasurer]

/// Create an instance that records 'metrics' during the training loop.
/// Creates an instance that records `metrics` during the training loop.
public init(metrics: [TrainingMetrics]) {
metricMeasurers = metrics.map { $0.measurer }

@@ -70,9 +74,9 @@
}
}

/// Recording statistics in response of the 'event'.
/// Records statistics in response to the `event`.
///
/// It will record the statistics into 'lastStatsLog' in the loop where other
/// It will record the statistics into the `lastStatsLog` property of the `loop`,
/// where other callbacks can consume them.
public func record<L: TrainingLoopProtocol>(_ loop: inout L, event: TrainingLoopEvent) throws {
guard let batchIndex = loop.batchIndex,
Expand All @@ -83,7 +87,6 @@ public class StatisticsRecorder {
let output = loop.lastStepOutput,
let target = loop.lastStepTarget
else {
// No-Op if trainingLoop doesn't set the required values for stats recording.
return
}

Expand All @@ -101,18 +104,22 @@ public class StatisticsRecorder {
}
}

/// Resets each of the metricMeasurers.
func resetMetricMeasurers() {
for index in metricMeasurers.indices {
metricMeasurers[index].reset()
}
}

/// Lets each of the metricMeasurers accumulate data from
/// `loss`, `predictions`, `labels`.
func accumulateMetrics<Output, Target>(loss: Tensor<Float>, predictions: Output, labels: Target) {
for index in metricMeasurers.indices {
metricMeasurers[index].accumulate(loss: loss, predictions: predictions, labels: labels)
}
}

/// Lets each of the metricMeasurers compute metrics on accumulated data.
func computeMetrics() -> [(String, Float)] {
var result: [(String, Float)] = []
for measurer in metricMeasurers {
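A sketch of overriding one of the predicates after construction. It assumes the predicates can simply be reassigned after init and that the TrainingMetrics enum exposes .loss and .accuracy cases (consistent with the measurer names in Metrics.swift below); the schedule itself is made up.

// Hypothetical: only compute metrics every 10 batches, or on the last batch of an epoch.
let recorder = StatisticsRecorder(metrics: [.loss, .accuracy])
recorder.shouldCompute = { batchIndex, batchCount, epochIndex, epochCount, event in
  event == .batchEnd && ((batchIndex + 1) % 10 == 0 || batchIndex + 1 == batchCount)
}
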
27 changes: 26 additions & 1 deletion TrainingLoop/Metrics.swift
@@ -24,32 +24,46 @@ public enum TrainingMetrics {
}
}

/// A protocal defining functionalities of a metrics measurer.
/// An accumulator of statistics.
public protocol MetricsMeasurer {
/// Name of the metrics.
var name: String { get set }

/// Clears the accumulated data and resets the measurer to its initial state.
mutating func reset()

/// Accumulates data from `loss`, `predictions`, `labels`.
mutating func accumulate<Output, Target>(
loss: Tensor<Float>?, predictions: Output?, labels: Target?
)

/// Computes metrics from accumulated data.
func measure() -> Float
}

/// A measurer for measuring loss.
public struct LossMeasurer: MetricsMeasurer {
/// Name of the LossMeasurer.
public var name: String

/// Sum of losses accumulated from batches.
private var totalBatchLoss: Float = 0

/// Count of batches accumulated so far.
private var batchCount: Int32 = 0

/// Creates a LossMeasurer named `name`.
public init(_ name: String = "loss") {
self.name = name
}

/// Resets totalBatchLoss and batchCount to zero.
public mutating func reset() {
totalBatchLoss = 0
batchCount = 0
}

/// Adds `loss` to totalBatchLoss and increases batchCount by one.
public mutating func accumulate<Output, Target>(
loss: Tensor<Float>?, predictions: Output?, labels: Target?
) {
@@ -59,27 +73,37 @@ public struct LossMeasurer: MetricsMeasurer {
}
}

/// Computes averaged loss.
public func measure() -> Float {
return totalBatchLoss / Float(batchCount)
}
}

/// A measurer for measuring accuracy.
public struct AccuracyMeasurer: MetricsMeasurer {
/// Name of the AccuracyMeasurer.
public var name: String

/// Count of correct guesses.
private var correctGuessCount: Int32 = 0

/// Count of total guesses.
private var totalGuessCount: Int32 = 0

/// Creates an AccuracyMeasurer named `name`.
public init(_ name: String = "accuracy") {
self.name = name
}

/// Resets correctGuessCount and totalGuessCount to zero.
public mutating func reset() {
correctGuessCount = 0
totalGuessCount = 0
}

/// Computes the correct guess count from `loss`, `predictions` and `labels`
/// and adds it to correctGuessCount; computes the total guess count from
/// the shape of `labels` and adds it to totalGuessCount.
public mutating func accumulate<Output, Target>(
loss: Tensor<Float>?, predictions: Output?, labels: Target?
) {
@@ -94,6 +118,7 @@ public struct AccuracyMeasurer: MetricsMeasurer {
totalGuessCount += Int32(labels.shape[0])
}

/// Computes accuracy as the fraction of correct guesses.
public func measure() -> Float {
return Float(correctGuessCount) / Float(totalGuessCount)
}
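To illustrate the protocol, a minimal hypothetical conformer (not part of this PR) that reports the largest per-batch loss seen since the last reset:

import TensorFlow

/// Hypothetical measurer that reports the largest per-batch loss since the last reset.
public struct MaxLossMeasurer: MetricsMeasurer {
  public var name: String

  /// Largest per-batch loss accumulated so far.
  private var maxLoss: Float = 0

  public init(_ name: String = "maxLoss") {
    self.name = name
  }

  public mutating func reset() {
    maxLoss = 0
  }

  public mutating func accumulate<Output, Target>(
    loss: Tensor<Float>?, predictions: Output?, labels: Target?
  ) {
    if let newBatchLoss = loss {
      maxLoss = max(maxLoss, newBatchLoss.scalarized())
    }
  }

  public func measure() -> Float {
    return maxLoss
  }
}
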
9 changes: 5 additions & 4 deletions TrainingLoop/TrainingLoop.swift
@@ -78,7 +78,7 @@ public protocol TrainingLoopProtocol {
/// The loss function.
var lossFunction: LossFunction { get set }

/// The metrics
/// The metrics on which training is measured.
var metrics: [TrainingMetrics] { get set }

// Callbacks
@@ -220,14 +220,15 @@

/// Callbacks

// MARK: - The callbacks used to customize the training loop.

/// The callbacks used to customize the training loop.
public var callbacks: [TrainingLoopCallback<Self>]

// MARK: - Default callback objects

/// The callback that records the training statistics.
public var statisticsRecorder: StatisticsRecorder? = nil

/// The callback that prints the training progress.
public var progressPrinter: ProgressPrinter? = nil

/// Temporary data
@@ -292,7 +293,7 @@
self.progressPrinter = progressPrinter
self.callbacks = [
statisticsRecorder.record,
progressPrinter.print,
progressPrinter.printProgress,
] + callbacks
} else {
self.callbacks = callbacks
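
For completeness, a hedged sketch of a free-function callback that could be appended to the `callbacks` array alongside the defaults wired above; the function name and printed message are purely illustrative.

// Hypothetical callback matching the TrainingLoopCallback signature.
func noteValidationStart<L: TrainingLoopProtocol>(
  _ loop: inout L, event: TrainingLoopEvent
) throws {
  if event == .validationStart {
    print("Starting validation for epoch \((loop.epochIndex ?? 0) + 1)")
  }
}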