@@ -0,0 +1,79 @@ (new file)

import Foundation
import ModelSupport

/// A callback-based handler for logging the statistics to CSV file.
Suggested change:
- /// A callback-based handler for logging the statistics to CSV file.
+ /// A callback-based handler for logging statistics to a CSV file.
I would consider not leading with “callback-based.” In fact, "logging" and "CSV file" are kind of implied by the name. So the best description would explain what's being logged. “Statistics” is good, but what kind of statistics? Training statistics, maybe? Maybe this should be thought of as “a log file” rather than “a logger”?
Updated the doc comment. Any reason why this is LogFile? (Logger makes sense to me since it's a logger, not a file.)
This thing can only be constructed to log to a path in the filesystem, and when we invoke its one public method, log, it appends to that file. Properly used, if you have multiple instances, each instance logs to a different file. So instances have a 1-1 correspondence to log files. The word Logger doesn't imply any of those things. A more general Logger might be an interesting abstraction, but it would have a different API.
If I see this code, I know exactly what's happening.
// No argument label needed, arguably, because when you construct a thing called “File” with
// a string, the string is obviously a path.
let eventLog = LogFile(s)
eventLog.append(blah, blah, blah)

With Logger and log, it's less clear.
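(For illustration, a minimal Swift sketch of the kind of type described above. The names `LogFile` and `append`, and the open-append-close behavior, are assumptions for this example, not the API in this PR.)

```swift
import Foundation

/// A log file at a fixed filesystem path; each instance corresponds to exactly one file.
/// Hypothetical sketch, not the type under review.
struct LogFile {
  let path: String

  /// Creates an empty file at `path` if one does not already exist.
  init(_ path: String) {
    self.path = path
    if !FileManager.default.fileExists(atPath: path) {
      FileManager.default.createFile(atPath: path, contents: nil)
    }
  }

  /// Appends one CSV row to the file.
  func append(_ fields: [String]) {
    guard let handle = FileHandle(forWritingAtPath: path) else { return }
    handle.seekToEndOfFile()
    handle.write((fields.joined(separator: ",") + "\n").data(using: .utf8)!)
    handle.closeFile()
  }
}

// Usage mirrors the snippet above:
// let eventLog = LogFile("/tmp/events.csv")
// eventLog.append(["epoch", "loss", "accuracy"])
```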
It seems very unlikely to me that we actually need to store path in addition to foundationFile. Consider whether it can/should be dropped because you can get it from foundationFile.
Same. Removed foundationFile and kept the path.
Hmm, I'm not sure you want to open and close the file for every line logged, though. (I am presuming that the foundationFile object keeps the file open)
How did you “resolve” this comment?
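(To sketch the alternative raised above, keeping the file open rather than reopening it for every logged line, here is one possible shape. The type name and details are hypothetical, not what this PR implements.)

```swift
import Foundation

/// A CSV logger that opens its file once and keeps the handle open for its lifetime.
/// Hypothetical sketch of the "keep the file open" alternative.
final class OpenHandleCSVLogger {
  let path: String
  private let handle: FileHandle

  init?(path: String) {
    if !FileManager.default.fileExists(atPath: path) {
      FileManager.default.createFile(atPath: path, contents: nil)
    }
    guard let handle = FileHandle(forWritingAtPath: path) else { return nil }
    handle.seekToEndOfFile()
    self.path = path
    self.handle = handle
  }

  /// Appends one row without reopening the file.
  func log(_ fields: [String]) {
    handle.write((fields.joined(separator: ",") + "\n").data(using: .utf8)!)
  }

  deinit {
    handle.closeFile()
  }
}
```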
Isn't it important to do this atomically if there are multiple writers to the same log file?
Does FileHandle support writing atomically? (Also, I don't think we expect multiple writers to the same file at the same time.)
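(For context on the question above: FileHandle itself offers no atomic-write option; Foundation's `.atomic` flag applies to whole-file writes such as `Data.write(to:options:)`, which replaces the file rather than appending. A small sketch with made-up contents and path:)

```swift
import Foundation

let url = URL(fileURLWithPath: "/tmp/stats.csv")  // hypothetical path

// Whole-file write: `.atomic` stages the data in a temporary file and swaps it in,
// but this replaces any existing contents instead of appending.
try "epoch,loss\n1,0.4819\n".data(using: .utf8)!.write(to: url, options: .atomic)

// Appending via FileHandle: there is no atomicity option here, so concurrent
// writers to the same file would need their own coordination (a lock or a queue).
if let handle = FileHandle(forWritingAtPath: url.path) {
  handle.seekToEndOfFile()
  handle.write("2,0.3921\n".data(using: .utf8)!)
  handle.closeFile()
}
```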
@@ -0,0 +1,100 @@ (new file)

// Copyright 2020 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import Foundation

let progressBarLength = 30
Length in what unit? I suppose it's probably characters. Being a top-level declaration, this should have a doc comment, and that would be a perfect place to put the answer. Is this the number of …
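(The requested doc comment might read something like the following; the unit, characters, is inferred from formatProgressBar below, which draws that many "=", ">", and "." characters.)

```swift
/// The width of the rendered progress bar, in characters: the number of "=", ">", and "."
/// characters drawn between the enclosing "[" and "]".
let progressBarLength = 30
```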

/// A callback-based handler for printing the training or validation progress.
public class ProgressPrinter {
  public var liveStatistics: Bool

  /// Creates an instance that prints progress during the training loop.
  ///
  /// The printed progress consists of a dynamic progress bar followed by statistics of the metrics.
  ///
  /// - Parameters:
  ///   - liveStatistics: whether to show statistics on every batch (`true`) or only after the
  ///     last batch of an epoch (`false`).
  public init(liveStatistics: Bool = true) {
    self.liveStatistics = liveStatistics
  }

  /// The callback used to hook into the TrainingLoop for printing progress.
  ///
  /// An example of the progress would be:
  ///   Epoch 1/12
  ///   468/468 [==============================] - loss: 0.4819 - accuracy: 0.8513
  ///   79/79 [==============================] - loss: 0.1520 - accuracy: 0.9521
  ///
  /// - Parameters:
  ///   - loop: The TrainingLoop where an event has occurred.
  ///   - event: The training or validation event that this callback is responding to.
  public func print<L: TrainingLoopProtocol>(_ loop: inout L, event: TrainingLoopEvent) throws {
    switch event {
    case .epochStart:
      guard let epochIndex = loop.epochIndex, let epochCount = loop.epochCount else {
        return
      }

      Swift.print("Epoch \(epochIndex + 1)/\(epochCount)")
    case .batchEnd:
      guard let batchIndex = loop.batchIndex, let batchCount = loop.batchCount else {
        return
      }

      let progressBar = formatProgressBar(
        progress: Float(batchIndex + 1) / Float(batchCount), length: progressBarLength)
      var stats: String = ""
      if liveStatistics || (batchCount == (batchIndex + 1)) {
        stats = formatStats(loop.lastStatsLog)
      }

      Swift.print(
        "\r\(batchIndex + 1)/\(batchCount) \(progressBar)\(stats)",
        terminator: ""
      )
      fflush(stdout)
    case .epochEnd:
      Swift.print("")
    case .validationStart:
      Swift.print("")
    default:
      return
    }
  }

  func formatProgressBar(progress: Float, length: Int) -> String {
    let progressSteps = Int(round(Float(length) * progress))
    let leading = String(repeating: "=", count: progressSteps)
    let separator: String
    let trailing: String
    if progressSteps < progressBarLength {
      separator = ">"
      trailing = String(repeating: ".", count: progressBarLength - progressSteps - 1)
    } else {
      separator = ""
      trailing = ""
    }
    return "[\(leading)\(separator)\(trailing)]"
  }

  func formatStats(_ stats: [(String, Float)]?) -> String {
    var result = ""
    if let stats = stats {
      for stat in stats {
        result += " - \(stat.0): \(String(format: "%.4f", stat.1))"
      }
    }
    return result
  }
}
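(To make the printed format concrete, here is what the two helpers above produce for a batch halfway through an epoch, reusing the statistics from the doc-comment example; this assumes it runs in the same module, since the helpers are internal.)

```swift
let printer = ProgressPrinter()

// 234 of 468 batches done: a 30-character bar, half filled.
let bar = printer.formatProgressBar(
  progress: Float(234) / Float(468), length: progressBarLength)
// bar == "[===============>..............]"

let stats = printer.formatStats([("loss", 0.4819), ("accuracy", 0.8513)])
// stats == " - loss: 0.4819 - accuracy: 0.8513"

print("234/468 \(bar)\(stats)")
// Prints: 234/468 [===============>..............] - loss: 0.4819 - accuracy: 0.8513
```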
@@ -0,0 +1,87 @@ (new file)
// Copyright 2020 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import TensorFlow

/// A callback-based handler for recording statistics.
///
/// Data produced by this handler can be used by ProgressPrinter, CSVLogger, etc.
public class StatisticsRecorder {
  public var liveStatistics: Bool

  var metricMeasurers: [MetricsMeasurer]

  /// Creates an instance that records `metrics` during the training loop.
  ///
  /// Recording happens on every batch by default, or only when the last batch ends
  /// if `liveStatistics` is set to false.
  ///
  /// - Parameters:
  ///   - liveStatistics: whether to record on every batch. This has an impact on performance:
  ///     because of the materialization of tensors, updating values on every batch can reduce
  ///     training speed by up to 30%.
  ///   - metrics: an array of TrainingMetrics to record.
  public init(liveStatistics: Bool = true, metrics: [TrainingMetrics]) {
    self.liveStatistics = liveStatistics
    metricMeasurers = metrics.map { $0.measurer }
  }

  /// The callback used to hook into the TrainingLoop for recording statistics.
  ///
  /// It records the statistics into `lastStatsLog` on the loop, where other
  /// callbacks can consume them.
  ///
  /// - Parameters:
  ///   - loop: The TrainingLoop where an event has occurred.
  ///   - event: The training or validation event that this callback is responding to.
  public func record<L: TrainingLoopProtocol>(_ loop: inout L, event: TrainingLoopEvent) throws {
Again, this looks like it should be a mutating method on …
    switch event {
    case .trainingStart, .validationStart:
      resetMetricMeasurers()
    case .batchEnd:
      if let loss = loop.lastStepLoss, let output = loop.lastStepOutput,
        let target = loop.lastStepTarget
      {
        accumulateMetrics(loss: loss, predictions: output, labels: target)
      }

      if let batchIndex = loop.batchIndex, let batchCount = loop.batchCount {
        if liveStatistics || (batchCount == (batchIndex + 1)) {
          loop.lastStatsLog = computeMetrics()
        }
      }
    default:
      return
    }
  }

  func resetMetricMeasurers() {
    for index in metricMeasurers.indices {
      metricMeasurers[index].reset()
    }
  }

  func accumulateMetrics<Output, Target>(loss: Tensor<Float>, predictions: Output, labels: Target) {
    for index in metricMeasurers.indices {
      metricMeasurers[index].accumulate(loss: loss, predictions: predictions, labels: labels)
    }
  }

  func computeMetrics() -> [(String, Float)] {
    var result: [(String, Float)] = []
    for measurer in metricMeasurers {
      result.append((measurer.name, measurer.measure()))
    }
    return result
  }
}
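(The contract between StatisticsRecorder and ProgressPrinter is just the `[(String, Float)]` array stored in `loop.lastStatsLog` at `.batchEnd`; a sketch of that hand-off outside of any training loop, with made-up numbers.)

```swift
// What StatisticsRecorder.computeMetrics() produces and stores in loop.lastStatsLog...
let lastStatsLog: [(String, Float)] = [("loss", 0.1520), ("accuracy", 0.9521)]

// ...and what ProgressPrinter renders from it at the end of a validation epoch.
let rendered = ProgressPrinter().formatStats(lastStatsLog)
print("79/79 [==============================]\(rendered)")
// Prints: 79/79 [==============================] - loss: 0.1520 - accuracy: 0.9521
```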