MNIST/main.swift (64 changes: 47 additions & 17 deletions)
@@ -71,48 +71,78 @@ struct Classifier: Layer {
 }

 let epochCount = 12
-let batchSize = 100
+let batchSize = 128

 func minibatch<Scalar>(in x: Tensor<Scalar>, at index: Int) -> Tensor<Scalar> {
     let start = index * batchSize
     return x[start..<start+batchSize]
 }
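
A quick note on the slicing above, assuming the standard 60,000-image training set and 10,000-image test set: minibatch(in:at:) at index i returns the half-open slice [128 * i, 128 * (i + 1)), and because the loops below use integer division, any trailing partial batch is silently dropped.

    // Illustration only: with batchSize = 128,
    // 60,000 / 128 = 468 full training batches (96 images unused per epoch),
    // 10,000 / 128 = 78 full test batches (16 images unused).
    let third = minibatch(in: trainImages, at: 2)  // trainImages[256..<384]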

-let (images, numericLabels) = readMNIST(imagesFile: "train-images-idx3-ubyte",
-                                        labelsFile: "train-labels-idx1-ubyte")
-let labels = Tensor<Float>(oneHotAtIndices: numericLabels, depth: 10)
+let (trainImages, trainNumericLabels) = readMNIST(imagesFile: "train-images-idx3-ubyte",
+                                                  labelsFile: "train-labels-idx1-ubyte")
+let trainLabels = Tensor<Float>(oneHotAtIndices: trainNumericLabels, depth: 10)
+
+let (testImages, testNumericLabels) = readMNIST(imagesFile: "t10k-images-idx3-ubyte",
+                                                labelsFile: "t10k-labels-idx1-ubyte")
+let testLabels = Tensor<Float>(oneHotAtIndices: testNumericLabels, depth: 10)
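
The one-hot initializer used here expands each class index into a row of the identity; a minimal illustration:

    let indices = Tensor<Int32>([0, 2])
    let oneHot = Tensor<Float>(oneHotAtIndices: indices, depth: 3)
    // [[1.0, 0.0, 0.0],
    //  [0.0, 0.0, 1.0]]

Note that the loops below feed the numeric labels straight into softmaxCrossEntropy(logits:labels:); trainLabels and testLabels are only consulted for their shape[0] when counting batches.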

 var classifier = Classifier()
-let optimizer = RMSProp(for: classifier)
+
+let optimizer = Adam(for: classifier)

 print("Beginning training...")

+struct Statistics {
+    var correctGuessCount: Int = 0
+    var totalGuessCount: Int = 0
+    var totalLoss: Float = 0
+}
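
The new Statistics struct replaces the three loose counters of the old loop so the training and test passes can accumulate results independently. A small convenience (hypothetical, not part of this PR) shows how the accuracy arithmetic below falls out of it:

    extension Statistics {
        // Hypothetical helper, not in the PR: fraction of correct guesses so far.
        var accuracy: Float {
            return Float(correctGuessCount) / Float(totalGuessCount)
        }
    }

With that, trainAccuracy and testAccuracy below would reduce to trainStats.accuracy and testStats.accuracy.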

 // The training loop.
 for epoch in 1...epochCount {
-    var correctGuessCount = 0
-    var totalGuessCount = 0
-    var totalLoss: Float = 0
-    for i in 0 ..< Int(labels.shape[0]) / batchSize {
-        let x = minibatch(in: images, at: i)
-        let y = minibatch(in: numericLabels, at: i)
+    var trainStats = Statistics()
+    var testStats = Statistics()
+    Context.local.learningPhase = .training
+    for i in 0 ..< Int(trainLabels.shape[0]) / batchSize {
+        let x = minibatch(in: trainImages, at: i)
+        let y = minibatch(in: trainNumericLabels, at: i)
         // Compute the gradient with respect to the model.
         let 𝛁model = classifier.gradient { classifier -> Tensor<Float> in
             let ŷ = classifier(x)
             let correctPredictions = ŷ.argmax(squeezingAxis: 1) .== y
-            correctGuessCount += Int(Tensor<Int32>(correctPredictions).sum().scalarized())
-            totalGuessCount += batchSize
+            trainStats.correctGuessCount += Int(
+                Tensor<Int32>(correctPredictions).sum().scalarized())
+            trainStats.totalGuessCount += batchSize
             let loss = softmaxCrossEntropy(logits: ŷ, labels: y)
-            totalLoss += loss.scalarized()
+            trainStats.totalLoss += loss.scalarized()
             return loss
         }
         // Update the model's differentiable variables along the gradient vector.
         optimizer.update(&classifier.allDifferentiableVariables, along: 𝛁model)
     }
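
An aside on classifier.gradient { ... }: it differentiates the trailing closure with respect to the model and returns the resulting gradient, while letting the closure record statistics as a side effect before returning the loss. Under the Swift for TensorFlow API of this era, an equivalent spelling (an assumption, shown only for orientation) uses the free function gradient(at:in:):

    // Assumed-equivalent formulation via the free function; the closure's
    // return value is the loss being differentiated.
    let 𝛁model = gradient(at: classifier) { classifier -> Tensor<Float> in
        softmaxCrossEntropy(logits: classifier(x), labels: y)
    }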

-    let accuracy = Float(correctGuessCount) / Float(totalGuessCount)
+    Context.local.learningPhase = .inference
+    for i in 0 ..< Int(testLabels.shape[0]) / batchSize {
+        let x = minibatch(in: testImages, at: i)
+        let y = minibatch(in: testNumericLabels, at: i)
+        // Compute loss on test set
+        let ŷ = classifier(x)
+        let correctPredictions = ŷ.argmax(squeezingAxis: 1) .== y
+        testStats.correctGuessCount += Int(Tensor<Int32>(correctPredictions).sum().scalarized())
+        testStats.totalGuessCount += batchSize
+        let loss = softmaxCrossEntropy(logits: ŷ, labels: y)
+        testStats.totalLoss += loss.scalarized()
+    }
+
+    let trainAccuracy = Float(trainStats.correctGuessCount) / Float(trainStats.totalGuessCount)
+    let testAccuracy = Float(testStats.correctGuessCount) / Float(testStats.totalGuessCount)
     print("""
           [Epoch \(epoch)] \
-          Loss: \(totalLoss), \
-          Accuracy: \(correctGuessCount)/\(totalGuessCount) (\(accuracy))
+          Training Loss: \(trainStats.totalLoss), \
+          Training Accuracy: \(trainStats.correctGuessCount)/\(trainStats.totalGuessCount) \
+          (\(trainAccuracy)), \
+          Test Loss: \(testStats.totalLoss), \
+          Test Accuracy: \(testStats.correctGuessCount)/\(testStats.totalGuessCount) \
+          (\(testAccuracy))
           """)
 }
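
One more note on Context.local.learningPhase: flipping it to .inference before the test pass (and back to .training at the top of each epoch) is what tells phase-sensitive layers such as dropout, if the model uses any, to switch behavior. A defensive wrapper could make the reset automatic; the evaluating(_:) helper below is a hypothetical sketch, not part of the PR:

    // Hypothetical helper: run `body` in inference mode, then restore
    // whatever learning phase was previously in effect.
    func evaluating<Result>(_ body: () -> Result) -> Result {
        let previousPhase = Context.local.learningPhase
        Context.local.learningPhase = .inference
        defer { Context.local.learningPhase = previousPhase }
        return body()
    }

The diff keeps the two explicit assignments instead, which reads more simply in a linear script like this one.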
Binary file added MNIST/t10k-images-idx3-ubyte
Binary file not shown.
Binary file added MNIST/t10k-labels-idx1-ubyte
Binary file not shown.