Skip to content
This repository was archived by the owner on Apr 23, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 45 additions & 18 deletions MNIST/main.swift
Original file line number Diff line number Diff line change
Expand Up @@ -70,49 +70,76 @@ struct Classifier: Layer {
}
}

// Training hyperparameters: `epochCount` bounds the outer training loop and
// `batchSize` is the number of examples sliced out per minibatch.
// NOTE(review): each constant is declared twice with different forms/values;
// this looks like the removed and added sides of a diff shown together —
// confirm which pair is current.
let epochCount = 12
let batchSize = 100
let epochCount: Int = 12
let batchSize: Int = 128

/// Returns the `index`-th minibatch of `x`.
///
/// The result is `x` subscripted with the half-open range that starts at
/// `index * batchSize` and spans `batchSize` positions. No bounds checking is
/// done here; callers are expected to pass an `index` whose range fits in `x`.
///
/// - Parameters:
///   - x: The tensor to slice.
///   - index: Zero-based minibatch number.
/// - Returns: The selected slice of `x`.
func minibatch<Scalar>(in x: Tensor<Scalar>, at index: Int) -> Tensor<Scalar> {
    let lowerBound = index * batchSize
    let upperBound = lowerBound + batchSize
    return x[lowerBound..<upperBound]
}

// Load the training images and integer labels from the two idx files, then
// build a depth-10 one-hot Float tensor from the integer labels.
// NOTE(review): the un-prefixed names (`images`, `numericLabels`, `labels`) and
// the `train_`-prefixed names below appear to be two versions of the same
// statements shown together — confirm which set is current.
let (images, numericLabels) = readMNIST(imagesFile: "train-images-idx3-ubyte",
let (train_images, train_numericLabels) = readMNIST(imagesFile: "train-images-idx3-ubyte",
labelsFile: "train-labels-idx1-ubyte")
let labels = Tensor<Float>(oneHotAtIndices: numericLabels, depth: 10)
let train_labels = Tensor<Float>(oneHotAtIndices: train_numericLabels, depth: 10)

// Same loading and one-hot construction for the "t10k" test files.
let (test_images, test_numericLabels) = readMNIST(imagesFile: "t10k-images-idx3-ubyte",
labelsFile: "t10k-labels-idx1-ubyte")
let test_labels = Tensor<Float>(oneHotAtIndices: test_numericLabels, depth: 10)

// Model instance and the optimizer constructed for it.
// NOTE(review): `optimizer` is declared twice (RMSProp, then Adam); presumably
// only one of these is live — confirm.
var classifier = Classifier()
let optimizer = RMSProp(for: classifier)

let optimizer = Adam(for: classifier)

print("Beginning training...")

/// Running totals accumulated over one pass through a dataset.
///
/// Every field starts at zero, so `stats()` yields an empty tally that the
/// caller increments as minibatches are processed.
struct stats {
    /// Number of predictions counted as correct so far.
    var correctGuessCount = 0
    /// Number of predictions counted so far.
    var totalGuessCount = 0
    /// Sum of the loss values added so far.
    var totalLoss = Float(0)
}

// The training loop.
// Each epoch runs optimizer updates over the training minibatches, then an
// evaluation pass over the test minibatches, and prints the accumulated loss
// and accuracy figures.
// NOTE(review): this region appears to show the removed and added sides of a
// diff together (bare counters alongside `stats` instances, two `for i`
// headers, two sets of print lines) — confirm which lines are current before
// relying on it.
for epoch in 1...epochCount {
var correctGuessCount = 0
var totalGuessCount = 0
var totalLoss: Float = 0
for i in 0 ..< Int(labels.shape[0]) / batchSize {
let x = minibatch(in: images, at: i)
let y = minibatch(in: numericLabels, at: i)
// Fresh tallies for this epoch, one per dataset.
var train_stats = stats()
var test_stats = stats()
// Set the learning phase to training for the optimization steps below.
Context.local.learningPhase = .training
// Integer division: examples beyond the last full minibatch are not visited.
for i in 0 ..< Int(train_labels.shape[0]) / batchSize {
let x = minibatch(in: train_images, at: i)
let y = minibatch(in: train_numericLabels, at: i)
// Compute the gradient with respect to the model.
let 𝛁model = classifier.gradient { classifier -> Tensor<Float> in
let ŷ = classifier(x)
// Accuracy/loss bookkeeping happens inside the closure so it can reuse ŷ.
let correctPredictions = ŷ.argmax(squeezingAxis: 1) .== y
correctGuessCount += Int(Tensor<Int32>(correctPredictions).sum().scalarized())
totalGuessCount += batchSize
train_stats.correctGuessCount += Int(Tensor<Int32>(correctPredictions).sum().scalarized())
// Each step counts a full batchSize, matching the fixed-size slice taken by minibatch.
train_stats.totalGuessCount += batchSize
let loss = softmaxCrossEntropy(logits: ŷ, labels: y)
totalLoss += loss.scalarized()
// Per-batch losses are summed; the printed figure is a total, not an average.
train_stats.totalLoss += loss.scalarized()
return loss
}
// Update the model's differentiable variables along the gradient vector.
optimizer.update(&classifier.allDifferentiableVariables, along: 𝛁model)
}

let accuracy = Float(correctGuessCount) / Float(totalGuessCount)
// Set the learning phase to inference before evaluating on the test data.
Context.local.learningPhase = .inference
// As above, integer division skips any trailing partial minibatch.
for i in 0 ..< Int(test_labels.shape[0]) / batchSize {
let x = minibatch(in: test_images, at: i)
let y = minibatch(in: test_numericLabels, at: i)
// Compute loss on test set
let ŷ = classifier(x)
let correctPredictions = ŷ.argmax(squeezingAxis: 1) .== y
test_stats.correctGuessCount += Int(Tensor<Int32>(correctPredictions).sum().scalarized())
test_stats.totalGuessCount += batchSize
let loss = softmaxCrossEntropy(logits: ŷ, labels: y)
test_stats.totalLoss += loss.scalarized()
}

// Accuracy is the fraction of counted guesses that were correct.
let train_accuracy = Float(train_stats.correctGuessCount) / Float(train_stats.totalGuessCount)
let test_accuracy = Float(test_stats.correctGuessCount) / Float(test_stats.totalGuessCount)
print("""
[Epoch \(epoch)] \
Loss: \(totalLoss), \
Accuracy: \(correctGuessCount)/\(totalGuessCount) (\(accuracy))
Training Loss: \(train_stats.totalLoss), \
Training Accuracy: \(train_stats.correctGuessCount)/\(train_stats.totalGuessCount) (\(train_accuracy)), \
Test Loss: \(test_stats.totalLoss), \
Test Accuracy: \(test_stats.correctGuessCount)/\(test_stats.totalGuessCount) (\(test_accuracy))
""")
}
}
Binary file added MNIST/t10k-images-idx3-ubyte
Binary file not shown.
Binary file added MNIST/t10k-labels-idx1-ubyte
Binary file not shown.