This project provides a hands-on implementation and exploration of Generative Pre-trained Transformer (GPT) models using PyTorch. It covers the fundamental building blocks of the Transformer architecture, data preparation pipelines for large text datasets, and fine-tuning techniques for adapting pre-trained models to specific downstream tasks like instruction following and text classification.
Goal: To demonstrate a practical understanding of modern Large Language Model (LLM) architecture, pre-training data handling, and fine-tuning methodologies relevant to AI Engineering roles.
- Key Features
- Technical Implementation
- Setup
- Usage
- Project Structure
- Demonstrated Skills
- Technologies Used
- Potential Future Work
## Key Features

- GPT Model Implementation: A clear PyTorch implementation of the decoder-only Transformer architecture, including Multi-Head Self-Attention, Layer Normalization, GELU activation, and Positional Embeddings.
- Pre-trained Weight Loading: Functionality to download and load weights from official OpenAI GPT-2 checkpoints (various sizes) into the custom model implementation.
- Data Preparation Pipelines: Scripts for downloading, preprocessing, and tokenizing large datasets suitable for language model training:
  - FineWeb-Edu: For pre-training (demonstrated in `prep_fineweb.py`).
  - HellaSwag: For evaluating common-sense reasoning (demonstrated in `prep_hellaswag.py`).
  - Instruction Dataset: Custom JSON loading for instruction fine-tuning (`finetune_instruction.py`).
  - SMS Spam Dataset: Loading and preprocessing for classification fine-tuning (`finetune_spam_classifier.py`).
- Instruction Fine-tuning: Adapting a pre-trained GPT-2 model to follow natural language instructions based on an Alpaca-style dataset. Includes custom data loading, padding/masking strategies, and generation logic.
- Classification Fine-tuning: Adapting a pre-trained GPT-2 model for a binary classification task (SMS spam detection) by modifying the output head and selectively training layers.
- Evaluation:
  - Standard loss and accuracy metrics for training and validation.
  - Text generation capabilities for qualitative assessment.
  - LLM-as-Judge evaluation using Ollama (e.g., Llama 3.1) to score the quality of instruction-following responses (`finetune_instruction.py`).
- Hardware Acceleration: Supports training and inference on CUDA GPUs or Apple Silicon (MPS) where available.
## Technical Implementation

### GPT Model (`gpt_model.py`)

The heart of the project is the `GPTModel` class, which implements a decoder-only Transformer architecture inspired by GPT-2 (a condensed sketch follows the list below).

- Embeddings: Uses separate embeddings for tokens (`nn.Embedding`) and positions (`nn.Embedding`).
- Transformer Blocks: Composed of multiple `TransformerBlock` layers. Each block contains:
  - `MultiHeadAttention`: Implements multi-head self-attention with optional masking for causal language modeling, plus dropout and optional query/key/value biases (`qkv_bias`).
  - `LayerNorm`: Custom Layer Normalization implementation.
  - `FeedForward`: A position-wise feed-forward network using GELU activation.
  - Residual Connections & Dropout: Standard residual connections (`x = x + dropout(sublayer(norm(x)))`) are used around the attention and feed-forward layers.
- Output Head: A final Layer Normalization followed by a Linear layer (`out_head`) projects the final hidden states to the vocabulary size, producing the logits used for next-token prediction.
- Weight Loading: The `load_weights_into_gpt` function carefully maps and reshapes weights from OpenAI's TensorFlow checkpoints into the PyTorch model's layers, handling differences in naming conventions and tensor dimensions.
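A minimal sketch of this structure, substituting `torch.nn` built-ins (`nn.MultiheadAttention`, `nn.LayerNorm`) for the repo's custom `MultiHeadAttention` and `LayerNorm`; the constructor arguments and attribute names are illustrative rather than the exact configuration in `gpt_model.py`:

```python
import torch
import torch.nn as nn

class TransformerBlock(nn.Module):
    """Pre-norm decoder block: x = x + drop(attn(norm(x))), then the same for the feed-forward."""
    def __init__(self, emb_dim, n_heads, context_length, drop_rate, qkv_bias=False):
        super().__init__()
        self.norm1 = nn.LayerNorm(emb_dim)
        self.attn = nn.MultiheadAttention(emb_dim, n_heads, dropout=drop_rate,
                                          bias=qkv_bias, batch_first=True)
        self.norm2 = nn.LayerNorm(emb_dim)
        self.ff = nn.Sequential(nn.Linear(emb_dim, 4 * emb_dim), nn.GELU(),
                                nn.Linear(4 * emb_dim, emb_dim))
        self.drop = nn.Dropout(drop_rate)
        # Causal mask: position i may only attend to positions <= i.
        mask = torch.triu(torch.ones(context_length, context_length), diagonal=1).bool()
        self.register_buffer("mask", mask)

    def forward(self, x):
        seq_len = x.size(1)
        h = self.norm1(x)
        attn_out, _ = self.attn(h, h, h, attn_mask=self.mask[:seq_len, :seq_len],
                                need_weights=False)
        x = x + self.drop(attn_out)                      # residual around attention
        x = x + self.drop(self.ff(self.norm2(x)))        # residual around feed-forward
        return x

class GPTModel(nn.Module):
    def __init__(self, vocab_size, context_length, emb_dim, n_heads, n_layers, drop_rate):
        super().__init__()
        self.tok_emb = nn.Embedding(vocab_size, emb_dim)       # token embeddings
        self.pos_emb = nn.Embedding(context_length, emb_dim)   # learned positional embeddings
        self.drop = nn.Dropout(drop_rate)
        self.blocks = nn.Sequential(*[
            TransformerBlock(emb_dim, n_heads, context_length, drop_rate)
            for _ in range(n_layers)])
        self.final_norm = nn.LayerNorm(emb_dim)
        self.out_head = nn.Linear(emb_dim, vocab_size, bias=False)

    def forward(self, idx):                              # idx: (batch, seq_len) token ids
        seq_len = idx.size(1)
        pos = torch.arange(seq_len, device=idx.device)
        x = self.drop(self.tok_emb(idx) + self.pos_emb(pos))
        x = self.blocks(x)
        return self.out_head(self.final_norm(x))         # (batch, seq_len, vocab_size) logits
```

Note that loading the official GPT-2 checkpoints requires `qkv_bias=True`, since OpenAI's attention projections include biases.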
### Data Preparation (`prep_fineweb.py`, `prep_hellaswag.py`)

Efficient data handling is crucial for LLMs.

`prep_fineweb.py` (sketched below):

- Downloads the FineWeb-Edu dataset (a large, filtered web crawl) using the `datasets` library.
- Uses `tiktoken` (GPT-2 encoding) for fast tokenization.
- Employs multiprocessing (`multiprocessing.Pool`) to parallelize tokenization across CPU cores.
- Implements sharding: saves the tokenized data into multiple smaller `.npy` files of a fixed token count (`shard_size`) so that datasets too large for memory can be handled, including logic to split documents across shards correctly.
- Designates the first shard as the validation set.
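A condensed sketch of that tokenize-and-shard loop (the dataset config, shard size, worker count, and file naming below are illustrative; `prep_fineweb.py` may organize this differently):

```python
import multiprocessing as mp
import os

import numpy as np
import tiktoken
from datasets import load_dataset

enc = tiktoken.get_encoding("gpt2")
EOT = enc.eot_token                               # <|endoftext|>, used as a document delimiter

def tokenize(doc):
    # Prepend the end-of-text token so documents remain separable after concatenation.
    tokens = [EOT] + enc.encode_ordinary(doc["text"])
    return np.array(tokens, dtype=np.uint16)      # GPT-2 ids (< 50257) fit in uint16

if __name__ == "__main__":
    shard_size = 100_000_000                      # tokens per shard (illustrative)
    os.makedirs("edu_fineweb10B", exist_ok=True)
    ds = load_dataset("HuggingFaceFW/fineweb-edu", name="sample-10BT", split="train")

    buf = np.empty(shard_size, dtype=np.uint16)
    filled, shard_idx = 0, 0
    with mp.Pool(max(1, mp.cpu_count() - 1)) as pool:
        for tokens in pool.imap(tokenize, ds, chunksize=16):
            while len(tokens):                    # a document may straddle a shard boundary
                take = min(shard_size - filled, len(tokens))
                buf[filled:filled + take] = tokens[:take]
                filled += take
                tokens = tokens[take:]
                if filled == shard_size:
                    split = "val" if shard_idx == 0 else "train"   # first shard = validation set
                    np.save(f"edu_fineweb10B/edufineweb_{split}_{shard_idx:06d}", buf)
                    shard_idx, filled = shard_idx + 1, 0
    if filled:                                    # flush the final, partially filled shard
        np.save(f"edu_fineweb10B/edufineweb_train_{shard_idx:06d}", buf[:filled])
```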
`prep_hellaswag.py` (evaluation recipe sketched below):

- Downloads the HellaSwag dataset (multiple-choice common-sense reasoning).
- Provides functions (`render_example`, `iterate_examples`) to process the JSONL data, tokenize the context and endings using `tiktoken`, and format them into tensors suitable for evaluation (calculating the loss on each candidate ending).
- Includes an example evaluation loop using Hugging Face's `transformers` implementation of GPT-2 for comparison/validation purposes (though the main focus is the custom model).
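To make the evaluation recipe concrete, here is a hedged sketch of rendering one example and picking the ending with the lowest average loss. The field names (`ctx`, `endings`, `label`) follow the public HellaSwag JSONL format, and `model` is assumed to map token ids to logits of shape `(batch, seq_len, vocab_size)`, as the custom `GPTModel` does; the actual `render_example` in `prep_hellaswag.py` may differ in detail:

```python
import torch
import torch.nn.functional as F
import tiktoken

enc = tiktoken.get_encoding("gpt2")

def render_example(example):
    """Turn one HellaSwag record into (tokens, mask, label): one row per candidate ending.

    `mask` is 1 over the ending tokens only, so the loss is measured on the
    completion rather than on the shared context.
    """
    ctx_tokens = enc.encode(example["ctx"])
    rows, masks = [], []
    for ending in example["endings"]:
        end_tokens = enc.encode(" " + ending)            # leading space matches GPT-2's BPE
        rows.append(ctx_tokens + end_tokens)
        masks.append([0] * len(ctx_tokens) + [1] * len(end_tokens))
    max_len = max(len(row) for row in rows)
    tokens = torch.zeros((4, max_len), dtype=torch.long)
    mask = torch.zeros((4, max_len), dtype=torch.long)
    for i, (row, m) in enumerate(zip(rows, masks)):
        tokens[i, : len(row)] = torch.tensor(row)
        mask[i, : len(m)] = torch.tensor(m)
    return tokens, mask, example["label"]

@torch.no_grad()
def most_likely_ending(model, tokens, mask):
    """Score each candidate ending by its average cross-entropy and return the argmin."""
    logits = model(tokens)                               # (4, T, vocab_size)
    shift_logits = logits[:, :-1, :]                     # prediction for token t+1 lives at position t
    shift_tokens = tokens[:, 1:]
    losses = F.cross_entropy(
        shift_logits.reshape(-1, shift_logits.size(-1)),
        shift_tokens.reshape(-1), reduction="none").view(4, -1)
    shift_mask = mask[:, 1:].float()                     # count loss only on ending tokens
    avg_loss = (losses * shift_mask).sum(dim=1) / shift_mask.sum(dim=1)
    return int(avg_loss.argmin())
```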
### Instruction Fine-tuning (`finetune_instruction.py`)

This script demonstrates fine-tuning a pre-loaded GPT-2 model (e.g., `gpt2-medium`) on an instruction-following task.

- Data Loading: Reads a JSON dataset (Alpaca format assumed: `instruction`, `input`, `output`).
- Formatting: Creates prompts by combining the instruction, optional input, and response fields with specific separators (`### Instruction:`, `### Response:`).
- `InstructionDataset` & `DataLoader`: Uses a custom PyTorch `Dataset` and `DataLoader`.
- `custom_collate_fn`: A critical function that handles batching sequences of varying lengths (see the sketch after this list). It pads sequences to the maximum length in the batch using a specified `pad_token_id` and creates target tensors for the language-modeling loss. Importantly, it sets the loss target to `ignore_index` (-100) for padding tokens, and can also do so for prompt tokens to focus training on generating the response (though the current implementation seems to train on the whole sequence).
- Training Loop: Implements a standard fine-tuning loop using the `AdamW` optimizer, calculating cross-entropy loss, and performing backpropagation. Includes periodic evaluation on a validation set and generation of sample outputs.
- Generation: Implements a `generate` function with options for temperature scaling and top-k sampling for more diverse outputs, along with an end-of-sequence (`eos_id`) check.
- LLM-as-Judge Evaluation: After fine-tuning, generates responses for a test set and uses an external LLM (via Ollama, e.g., Llama 3.1) to score the generated responses against the ground truth, providing an automated quality-assessment metric.
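A minimal sketch of such a collate function, assuming each batch item is a plain list of token ids produced by `InstructionDataset` (the exact signature and eos/padding handling in `finetune_instruction.py` may differ; `50256` is GPT-2's `<|endoftext|>` id and `-100` matches PyTorch's default cross-entropy `ignore_index`):

```python
import torch

def custom_collate_fn(batch, pad_token_id=50256, ignore_index=-100, device="cpu"):
    """Pad variable-length sequences and build next-token targets for LM fine-tuning."""
    batch_max_len = max(len(item) + 1 for item in batch)     # +1 for the appended end-of-text token
    inputs, targets = [], []
    for item in batch:
        padded = item + [pad_token_id] * (batch_max_len - len(item))
        input_ids = torch.tensor(padded[:-1])                # inputs: everything but the last token
        target_ids = torch.tensor(padded[1:])                # targets: inputs shifted left by one
        pad_positions = torch.nonzero(target_ids == pad_token_id).squeeze(-1)
        if pad_positions.numel() > 1:                        # keep one eos target, ignore the rest
            target_ids[pad_positions[1:]] = ignore_index
        inputs.append(input_ids)
        targets.append(target_ids)
    return torch.stack(inputs).to(device), torch.stack(targets).to(device)
```

Hooked into a `DataLoader` via `functools.partial(custom_collate_fn, device=device)`, this keeps the padding logic out of the training loop; masking the prompt tokens as well would only require extending the `ignore_index` assignment to the prompt span.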
### Classification Fine-tuning (`finetune_spam_classifier.py`)

This script adapts the pre-trained GPT-2 model for SMS spam classification.
- Data Loading: Downloads the SMS Spam Collection dataset, preprocesses it using `pandas`, creates a balanced dataset (downsampling the majority class, 'ham'), and splits it into train/validation/test sets.
- `SpamDataset`: Custom dataset that tokenizes text messages and pads them to a maximum length.
- Model Adaptation (see the sketch at the end of this section):
  - Freezing Layers: Most of the pre-trained model's parameters are frozen (`param.requires_grad = False`) to preserve the learned language representations.
  - Modifying Output Head: The original `out_head` (mapping to the vocabulary size) is replaced with a new `nn.Linear` layer mapping the final hidden-state dimension (`emb_dim`) to the number of classes (2, spam/ham).
  - Selective Unfreezing: Only the parameters of the newly added output head, the final `LayerNorm`, and the last `TransformerBlock` are unfrozen (`requires_grad = True`) for training. This is a common and efficient fine-tuning strategy.
- Training & Evaluation:
  - Uses the `AdamW` optimizer on the unfrozen parameters.
  - Calculates classification accuracy (`calc_accuracy_loader`) and cross-entropy loss (`calc_loss_batch`, `calc_loss_loader`), using the logits produced at the last token position of the input sequence.
  - Plots training/validation loss and accuracy over epochs.
- Inference: Provides a `classify_review` function to classify new text messages using the fine-tuned model.
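A hedged sketch of that adaptation, reusing the attribute names from the model sketch in the Technical Implementation section (`out_head`, `blocks`, `final_norm`); the attribute names and label mapping in `finetune_spam_classifier.py` may differ:

```python
import torch
import torch.nn as nn

def adapt_for_classification(model, emb_dim, num_classes=2):
    """Freeze the backbone, swap the LM head for a classifier head, and unfreeze only the top layers."""
    for param in model.parameters():                      # freeze every pre-trained parameter
        param.requires_grad = False
    model.out_head = nn.Linear(emb_dim, num_classes)      # new head, trained from scratch
    for param in model.blocks[-1].parameters():           # unfreeze the last TransformerBlock
        param.requires_grad = True
    for param in model.final_norm.parameters():           # ...and the final LayerNorm
        param.requires_grad = True
    return model

@torch.no_grad()
def classify(model, token_ids):
    """Predict classes from the logits at the last token position of each sequence."""
    model.eval()
    logits = model(token_ids)[:, -1, :]                   # (batch, num_classes)
    return logits.argmax(dim=-1)                          # e.g. 0 = ham, 1 = spam (mapping assumed)
```

The `AdamW` optimizer can then be built over `(p for p in model.parameters() if p.requires_grad)` so that only the unfrozen subset is updated.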
## Setup

- Clone the repository:

  ```bash
  git clone https://github.com/mhuang448/llm-from-scratch.git
  cd llm-from-scratch
  ```

- Create a virtual environment (recommended):

  ```bash
  python -m venv venv
  source venv/bin/activate  # On Windows use `venv\Scripts\activate`
  ```

- Install dependencies:

  ```bash
  pip install -r requirements.txt
  ```

  Likely dependencies include `torch`, `tiktoken`, `numpy`, `datasets`, `tqdm`, `requests`, `pandas`, `matplotlib`, `transformers`, and `psutil`. A `requirements.txt` still needs to be created; a starting point is sketched below.

- (Optional) Install Ollama: For the LLM-as-Judge evaluation in `finetune_instruction.py`, follow the instructions at https://ollama.com/ to install and run Ollama, then pull a model such as Llama 3.1: `ollama pull llama3.1`.
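A minimal `requirements.txt` sketch based on the dependency list above (versions unpinned; pin them as needed for reproducibility):

```text
torch
tiktoken
numpy
datasets
tqdm
requests
pandas
matplotlib
transformers
psutil
```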
## Usage

### Data Preparation

- FineWeb-Edu:

  ```bash
  python prep_fineweb.py
  # Downloads the data and creates tokenized shards in ./edu_fineweb10B/
  ```

- HellaSwag (download only): The `prep_hellaswag.py` script is primarily for evaluation setup but includes download logic triggered by `iterate_examples`. Running the evaluation script (see below) will handle this.
- SMS Spam: The `finetune_spam_classifier.py` script handles downloading and preprocessing automatically on the first run.
- Instruction Data: Place your `instruction-data.json` file in the project's root directory (or update the path in `finetune_instruction.py`).
### Fine-tuning

- Instruction Fine-tuning:

  ```bash
  # Ensure Ollama is running if you want the final evaluation step
  ollama serve &  # Run in the background (example; may vary by OS)

  # Run the fine-tuning script (adjust model size, epochs, etc. inside the script)
  python finetune_instruction.py
  # Downloads GPT-2 weights, fine-tunes, saves the model to instruction_tuned_medium.pth (or similar),
  # generates examples, and runs the Ollama evaluation.
  ```

- Spam Classification Fine-tuning:

  ```bash
  # Run the fine-tuning script (adjust model size, epochs, etc. inside the script)
  python finetune_spam_classifier.py
  # Downloads data/weights, fine-tunes, saves the model to spam_classifier.pth,
  # plots metrics, and shows example classifications.
  ```
### Evaluation

- HellaSwag (using a Hugging Face model):

  ```bash
  python prep_hellaswag.py -m gpt2-xl -d cuda  # Or gpt2, gpt2-medium, etc.; use -d cpu if no GPU
  ```

- Instruction Following (LLM-as-Judge): Performed automatically at the end of `finetune_instruction.py`. Requires Ollama to be running (the underlying request is sketched below).
- Spam Classification: Accuracy/loss metrics are printed and plotted during and after running `finetune_spam_classifier.py`.
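For reference, the LLM-as-Judge step boils down to posting a scoring prompt to the local Ollama server. A minimal sketch of that call using Ollama's `/api/chat` endpoint (the actual prompt template and helper names in `finetune_instruction.py` are not shown here and may differ):

```python
import json
import urllib.request

def query_ollama(prompt, model="llama3.1", url="http://localhost:11434/api/chat"):
    """Send one prompt to a locally running Ollama server and return the model's reply."""
    payload = {
        "model": model,
        "messages": [{"role": "user", "content": prompt}],
        "stream": False,                      # return a single JSON object instead of a stream
    }
    request = urllib.request.Request(
        url, data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"}, method="POST")
    with urllib.request.urlopen(request) as response:
        return json.loads(response.read())["message"]["content"]

# Illustrative judging prompt: ask for an integer score for one generated response.
score_prompt = (
    "Given the instruction `What is the capital of France?` and the correct output `Paris`, "
    "score the model response `The capital of France is Paris.` on a scale from 0 to 100. "
    "Respond with the integer number only."
)
print(query_ollama(score_prompt))
```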
## Project Structure

```
.
├── gpt_model.py                  # Core GPT model implementation & weight loading
├── finetune_instruction.py       # Script for instruction fine-tuning & evaluation
├── finetune_spam_classifier.py   # Script for spam classification fine-tuning
├── prep_fineweb.py               # Script for downloading and tokenizing FineWeb-Edu
├── prep_hellaswag.py             # Script for downloading and evaluating HellaSwag
├── instruction-data.json         # Example instruction dataset (user needs to provide)
├── edu_fineweb10B/               # FineWeb tokenized shards (created by prep_fineweb.py)
├── hellaswag/                    # HellaSwag data (created by prep_hellaswag.py)
├── gpt2/                         # Downloaded GPT-2 model weights (created by scripts)
├── *.csv                         # Train/val/test splits for spam data (created by finetune_spam_classifier.py)
├── *.pth                         # Saved fine-tuned model weights (created by fine-tuning scripts)
├── *.pdf                         # Plots generated by fine-tuning scripts
├── requirements.txt              # Project dependencies (user should create)
└── README.md                     # This file
```
## Demonstrated Skills

- Deep Learning Frameworks: Proficient use of PyTorch for building and training complex neural networks.
- Transformer Architecture: Deep understanding and implementation of core components like Multi-Head Self-Attention, Positional Encodings, Layer Normalization, and Residual Connections.
- LLM Fine-tuning: Practical application of various fine-tuning strategies:
  - Full model fine-tuning (implied in instruction tuning, though this depends on the optimizer setup).
  - Parameter-efficient fine-tuning principles (layer freezing, adapting output heads for classification).
- Data Engineering for AI:
  - Handling large datasets (sharding, efficient tokenization).
  - Preprocessing diverse data formats (JSON, CSV, web text).
  - Creating custom Datasets and DataLoaders in PyTorch, including sophisticated collation (padding, masking).
  - Dataset balancing techniques.
- Tokenization: Using standard tokenizers (`tiktoken`) effectively.
- MLOps Concepts:
  - Model persistence (saving/loading weights).
  - Integration with pre-trained models (GPT-2).
  - Evaluation methodologies (loss, accuracy, generation, LLM-as-Judge).
  - Dependency management.
- Software Engineering Practices: Modular code structure, clear function definitions, use of virtual environments.
## Technologies Used

- Python 3.x
- PyTorch: Core deep learning framework.
- tiktoken: Fast BPE tokenizer from OpenAI.
- NumPy: Numerical operations.
- Pandas: Data manipulation for CSV files.
- Hugging Face `datasets`: Downloading large public datasets.
- Matplotlib: Plotting training metrics.
- tqdm: Progress bars.
- requests: Downloading files (used in `prep_hellaswag.py`).
- Ollama (optional): Local LLM runner for evaluation.
- psutil (optional): Checking whether the Ollama process is running.
## Potential Future Work

- Implement more advanced fine-tuning techniques (e.g., QLoRA).
- Integrate with ML experiment tracking tools (e.g., Weights & Biases, MLflow).
- Add support for different model architectures or pre-trained weights.
- Implement distributed training for larger models/datasets.
- Develop a simple API (e.g., using FastAPI) to serve the fine-tuned models.
- Expand evaluation suites (e.g., ROUGE scores for summarization, BLEU for translation if applicable tasks are added).