-
Notifications
You must be signed in to change notification settings - Fork 150
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,74 @@ | ||
| import Foundation | ||
| import ModelSupport | ||
|
|
||
| public enum CSVLoggerError: Error { | ||
| case InvalidPath | ||
| } | ||
|
|
||
| /// A handler for logging training and validation statistics to a CSV file. | ||
| public class CSVLogger { | ||
| /// The path of the file that statistics are logged to. | ||
xihui-wu marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| public var path: String | ||
xihui-wu marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| // True iff the header of the CSV file has been written. | ||
| fileprivate var headerWritten: Bool | ||
|
|
||
| /// Creates an instance that logs to a file with the given path. | ||
| /// | ||
| /// Throws: File system errors. | ||
| public init(path: String = "run/log.csv") throws { | ||
| self.path = path | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It seems very unlikely to me that we actually need to store
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same. Removed foundationFile and kept the path.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hmm, I'm not sure you want to open and close the file for every line logged, though. (I am presuming that the
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. How did you “resolve” this comment? |
||
|
|
||
| // Validate the path. | ||
| let url = URL(fileURLWithPath: path) | ||
| if url.pathExtension != "csv" { | ||
| throw CSVLoggerError.InvalidPath | ||
| } | ||
| // Create the containing directory if it is missing. | ||
| try FoundationFileSystem().createDirectoryIfMissing(at: url.deletingLastPathComponent().path) | ||
| // Initialize the file with empty string. | ||
| try FoundationFile(path: path).write(Data()) | ||
|
|
||
| self.headerWritten = false | ||
| } | ||
|
|
||
| /// Logs the statistics for the 'loop' when 'batchEnd' event happens; | ||
| /// ignoring other events. | ||
| /// | ||
| /// Throws: File system errors. | ||
| public func log<L: TrainingLoopProtocol>(_ loop: inout L, event: TrainingLoopEvent) throws { | ||
xihui-wu marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| switch event { | ||
| case .batchEnd: | ||
| guard let epochIndex = loop.epochIndex, let epochCount = loop.epochCount, | ||
| let batchIndex = loop.batchIndex, let batchCount = loop.batchCount, | ||
| let stats = loop.lastStatsLog | ||
| else { | ||
| // No-Op if trainingLoop doesn't set the required values for stats logging. | ||
| return | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why would any of these things be
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Merged stats into the same guard statements.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. OK, first, I would rather see something like because it describes the situation at a semantic level rather than at the level of what some code did. (also writing "No-Op" adds nothing to what is already very obvious from the code) But that said, it seems very unlikely that the comment I'd like to see is true of any but the last property. All the others refer to values that have nothing to do with logging. So I want to know what causes It's a big design flaw in trainingLoop that it has so many optionals, and that makes this task more difficult, but I believe that's not the code you're working on(?)
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. These variables were originally designed to be optionals to store temporary data in protocol: https://github.com/tensorflow/swift-models/blob/master/TrainingLoop/TrainingLoop.swift#L76 The generic TrainingLoop that implements the protocol does set all these optionals. My guess on why it was designed so is that it allows other TrainingLoops not setting them. So the point on the comment is NOT "if stats logging doesn't request it then these properties will be nil", but "if these properties are nil No-op on the CSVLogger".
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Then you should delete the comment. The code very clearly says that all by itself, so the comment explains nothing. I don't know if you're missing the point I'm trying to make, or you just disagree with it, but I'm doing this code review 100% for your benefit as a programmer. If the review process is blocking your progress, please feel free to just commit the changes, and decide separately about whether you want the feedback I'm giving you here. If you do, we can continue to discuss it. |
||
| } | ||
|
|
||
| if !headerWritten { | ||
dabrahams marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| try writeHeader(stats: stats) | ||
| headerWritten = true | ||
| } | ||
|
|
||
| try writeDataRow( | ||
xihui-wu marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| epoch: "\(epochIndex + 1)/\(epochCount)", | ||
| batch: "\(batchIndex + 1)/\(batchCount)", | ||
| stats: stats) | ||
| default: | ||
| return | ||
| } | ||
| } | ||
|
|
||
| func writeHeader(stats: [(name: String, value: Float)]) throws { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Doc comments are missing on this and the following method. |
||
| let header = (["epoch", "batch"] + stats.lazy.map { $0.name }).joined(separator: ", ") + "\n" | ||
| try FoundationFile(path: path).append(header.data(using: .utf8)!) | ||
| } | ||
|
|
||
| func writeDataRow(epoch: String, batch: String, stats: [(name: String, value: Float)]) throws { | ||
| let dataRow = ([epoch, batch] + stats.lazy.map { String($0.value) }).joined(separator: ", ") | ||
| + "\n" | ||
| try FoundationFile(path: path).append(dataRow.data(using: .utf8)!) | ||
| } | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,85 @@ | ||
| // Copyright 2020 The TensorFlow Authors. All Rights Reserved. | ||
| // | ||
| // Licensed under the Apache License, Version 2.0 (the "License"); | ||
| // you may not use this file except in compliance with the License. | ||
| // You may obtain a copy of the License at | ||
| // | ||
| // http://www.apache.org/licenses/LICENSE-2.0 | ||
| // | ||
| // Unless required by applicable law or agreed to in writing, software | ||
| // distributed under the License is distributed on an "AS IS" BASIS, | ||
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| // See the License for the specific language governing permissions and | ||
| // limitations under the License. | ||
|
|
||
| import Foundation | ||
|
|
||
| let progressBarLength = 30 | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. length in what unit? I suppose it's probably characters. Being a top-level declaration this should have a doc comment, and that would be a perfect place to put the answer. Is this the number of |
||
|
|
||
| /// A handler for printing the training and validation progress. | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't know what a "handler" is. That makes me suspect this class is not really representing any abstraction. Since it has no storage, there's no reason to make it a
A more appropriate idiom would be: extension TrainingLoopProtocol {
public mutating func printProgress(event: TrainingLoopEvent) throws { ... }
} |
||
| public class ProgressPrinter { | ||
| /// Print training or validation progress in response of the 'event'. | ||
| /// | ||
| /// An example of the progress would be: | ||
| /// Epoch 1/12 | ||
| /// 468/468 [==============================] - loss: 0.4819 - accuracy: 0.8513 | ||
| /// 79/79 [==============================] - loss: 0.1520 - accuracy: 0.9521 | ||
| public func print<L: TrainingLoopProtocol>(_ loop: inout L, event: TrainingLoopEvent) throws { | ||
| switch event { | ||
| case .epochStart: | ||
| guard let epochIndex = loop.epochIndex, let epochCount = loop.epochCount else { | ||
| // No-Op if trainingLoop doesn't set the required values for progress printing. | ||
| return | ||
| } | ||
|
|
||
| Swift.print("Epoch \(epochIndex + 1)/\(epochCount)") | ||
| case .batchEnd: | ||
| guard let batchIndex = loop.batchIndex, let batchCount = loop.batchCount else { | ||
| // No-Op if trainingLoop doesn't set the required values for progress printing. | ||
| return | ||
| } | ||
|
|
||
| let progressBar = formatProgressBar( | ||
| progress: Float(batchIndex + 1) / Float(batchCount), length: progressBarLength) | ||
| var stats: String = "" | ||
| if let lastStatsLog = loop.lastStatsLog { | ||
| stats = formatStats(lastStatsLog) | ||
| } | ||
|
|
||
| Swift.print( | ||
| "\r\(batchIndex + 1)/\(batchCount) \(progressBar)\(stats)", | ||
| terminator: "" | ||
| ) | ||
| fflush(stdout) | ||
| case .epochEnd: | ||
| Swift.print("") | ||
| case .validationStart: | ||
| Swift.print("") | ||
| default: | ||
| return | ||
| } | ||
| } | ||
|
|
||
| func formatProgressBar(progress: Float, length: Int) -> String { | ||
| let progressSteps = Int(round(Float(length) * progress)) | ||
| let leading = String(repeating: "=", count: progressSteps) | ||
| let separator: String | ||
| let trailing: String | ||
| if progressSteps < progressBarLength { | ||
| separator = ">" | ||
| trailing = String(repeating: ".", count: progressBarLength - progressSteps - 1) | ||
| } else { | ||
| separator = "" | ||
| trailing = "" | ||
| } | ||
| return "[\(leading)\(separator)\(trailing)]" | ||
| } | ||
|
|
||
| func formatStats(_ stats: [(String, Float)]) -> String { | ||
| var result = "" | ||
| for stat in stats { | ||
| result += " - \(stat.0): \(String(format: "%.4f", stat.1))" | ||
| } | ||
| return result | ||
| } | ||
| } | ||
Uh oh!
There was an error while loading. Please reload this page.