Credit to https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html
This project implements a Convolutional Neural Network (CNN) using PyTorch to classify images from the CIFAR-10 dataset. The dataset consists of 60,000 32x32 color images in 10 classes, with 6,000 images per class. The model is trained to classify images into one of the following categories:
- Airplane
- Automobile
- Bird
- Cat
- Deer
- Dog
- Frog
- Horse
- Ship
- Truck
The CIFAR-10 dataset is downloaded automatically through torchvision and comes pre-split into 50,000 training images and 10,000 test images.
Load and Normalize the CIFAR-10 Dataset
- Use `torchvision.datasets.CIFAR10` to load the dataset.
- Apply transformations, including normalization.
- Use `DataLoader` to load the data in mini-batches.
Visualize Sample Data
- Display random training images using `matplotlib`.
Define the CNN Model
- The model consists of:
  - Two convolutional layers (`Conv2d`)
  - Max pooling layers (`MaxPool2d`)
  - Fully connected layers (`Linear`)
  - Activation functions (`ReLU`)
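A sketch of a model with this structure. The channel counts (6 and 16), 5x5 kernels, and hidden widths (120 and 84) follow the credited tutorial and should be treated as assumptions about this project. The shape comments show why the first fully connected layer takes 16 * 5 * 5 inputs.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)    # 3 input channels (RGB), 6 filters, 5x5 kernel
        self.pool = nn.MaxPool2d(2, 2)     # halves height and width; reused after each conv
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)       # one raw score (logit) per CIFAR-10 class

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))  # 3x32x32 -> 6x28x28 -> 6x14x14
        x = self.pool(F.relu(self.conv2(x)))  # 6x14x14 -> 16x10x10 -> 16x5x5
        x = torch.flatten(x, 1)               # keep the batch dim, flatten the rest to 400
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)                    # no softmax: CrossEntropyLoss expects logits
```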
Define the Loss Function and Optimizer
- Use Cross-Entropy Loss (`nn.CrossEntropyLoss`) as the loss function.
- Use Stochastic Gradient Descent (SGD) with momentum for optimization.
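A minimal sketch of this setup, assuming `lr=0.001` and `momentum=0.9` as in the credited tutorial; `make_criterion_and_optimizer` is a hypothetical helper name. The second half checks that `nn.CrossEntropyLoss` is the negative log-softmax of the raw scores at the target class, which is why the model should not end in a softmax.

```python
import torch
import torch.nn as nn
import torch.optim as optim


def make_criterion_and_optimizer(model, lr=0.001, momentum=0.9):
    """Hypothetical helper: cross-entropy loss plus SGD with momentum."""
    criterion = nn.CrossEntropyLoss()  # expects raw logits and integer class labels
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)
    return criterion, optimizer


# CrossEntropyLoss applies log-softmax internally to the unnormalized scores.
logits = torch.tensor([[2.0, 0.5, -1.0]])  # one sample, three classes
target = torch.tensor([0])                 # index of the correct class
loss = nn.CrossEntropyLoss()(logits, target)
manual = -torch.log_softmax(logits, dim=1)[0, target[0]]
```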
Train the Model
- Train for multiple epochs.
- Print the average training loss every 2,000 mini-batches.
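The training loop can be sketched as below, assuming 2 epochs and a logging interval of 2,000 mini-batches as in the credited tutorial. `train` is a hypothetical helper; returning the logged averages is an addition for convenience, not something the tutorial does.

```python
import torch


def train(net, trainloader, criterion, optimizer, epochs=2, log_every=2000, device="cpu"):
    """Standard supervised training loop; returns the average losses it printed."""
    device = torch.device(device)
    net.to(device)
    net.train()
    logged = []
    for epoch in range(epochs):
        running_loss = 0.0
        for i, (inputs, labels) in enumerate(trainloader):
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()                  # clear gradients from the previous step
            loss = criterion(net(inputs), labels)  # forward pass + loss
            loss.backward()                        # backpropagate
            optimizer.step()                       # update the weights
            running_loss += loss.item()
            if i % log_every == log_every - 1:     # after every `log_every` mini-batches
                avg = running_loss / log_every
                print(f"[{epoch + 1}, {i + 1:5d}] loss: {avg:.3f}")
                logged.append(avg)
                running_loss = 0.0
    return logged
```

The loss is computed on every mini-batch; only the printed value is averaged over the logging interval.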
Test the Model
- Predict labels for test images.
- Compare predictions with ground truth.
- Compute the overall classification accuracy on the test set.
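A minimal evaluation sketch. `evaluate` is a hypothetical helper; it takes the highest-scoring class as the prediction and disables gradient tracking, since no weights are updated during testing.

```python
import torch


def evaluate(net, testloader, device="cpu"):
    """Return the overall accuracy (%) of `net` over every batch in `testloader`."""
    net.eval()
    correct = total = 0
    with torch.no_grad():  # inference only: skip building the autograd graph
        for images, labels in testloader:
            images, labels = images.to(device), labels.to(device)
            predicted = net(images).argmax(dim=1)  # class with the highest score
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return 100.0 * correct / total
```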
Analyze Model Performance
- Compute accuracy for each class separately.
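Per-class accuracy can be sketched by counting correct predictions separately for each true label. `per_class_accuracy` is a hypothetical helper; classes that never appear in the data are reported as NaN rather than raising a division error.

```python
import torch


def per_class_accuracy(net, testloader, classes, device="cpu"):
    """Return {class name: accuracy in %} computed per true label."""
    correct = {c: 0 for c in classes}
    total = {c: 0 for c in classes}
    net.eval()
    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.to(device), labels.to(device)
            predictions = net(images).argmax(dim=1)
            for label, pred in zip(labels.tolist(), predictions.tolist()):
                total[classes[label]] += 1
                if label == pred:
                    correct[classes[label]] += 1
    return {c: 100.0 * correct[c] / total[c] if total[c] else float("nan")
            for c in classes}
```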
Save the Trained Model
- Save the model parameters (the `state_dict`) using `torch.save()`.
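A sketch of saving and restoring the parameters, assuming the `./cifar_net.pth` path used in the credited tutorial; `save_model` and `load_model` are hypothetical names. Because only the `state_dict` is stored, the model class must be instantiated before the weights can be loaded back.

```python
import torch


def save_model(net, path="./cifar_net.pth"):
    """Write only the learned parameters (not the Python class) to `path`."""
    torch.save(net.state_dict(), path)


def load_model(net, path="./cifar_net.pth", device="cpu"):
    """Load saved parameters into an already constructed model of the same architecture."""
    # map_location lets weights saved on a GPU be loaded on a CPU-only machine.
    net.load_state_dict(torch.load(path, map_location=device))
    return net
```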
Train on GPU (if available)
- Move the model and data to a CUDA device for acceleration.
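A minimal device-selection sketch that falls back to the CPU when no GPU is present. The small linear model is a hypothetical stand-in for the CNN. Note the asymmetry: `Module.to()` moves parameters in place, while `Tensor.to()` returns a new tensor that must be reassigned.

```python
import torch
import torch.nn as nn

# Use the first CUDA GPU if one is available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

net = nn.Linear(3 * 32 * 32, 10)  # hypothetical stand-in for the CNN
net.to(device)                    # modules are moved in place

# Tensors are not moved in place: reassign the result of .to().
inputs = torch.randn(4, 3 * 32 * 32).to(device)
labels = torch.randint(0, 10, (4,)).to(device)
outputs = net(inputs)             # model and inputs must be on the same device
```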
To run this project, install the following dependencies:
pip install torch torchvision matplotlib numpy
Possible improvements:
- Increase model complexity (add more layers, or use a pretrained model such as ResNet).
- Use data augmentation techniques to improve generalization.
- Tune hyperparameters such as the learning rate, momentum, and batch size for better accuracy.
- Train for more epochs to improve performance.
This project provides a foundational approach to image classification using CNNs in PyTorch. 🚀