amazon-science/adamw-sr

Stochastic Rounding for LLM Training: Theory and Practice

Authors: Kaan Ozkara, Tao Yu, Youngsuk Park

This repository includes the official code for our paper Stochastic Rounding for LLM Training: Theory and Practice (AISTATS 2025).
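Stochastic rounding (SR) rounds a value up or down to a neighboring representable value with probability proportional to its distance from each neighbor, so the rounding error is zero in expectation. As intuition only (this is a simplified illustration, not the optimizer implementation in this repository), a minimal sketch of stochastically rounding an FP32 tensor to BF16 could look like:

import torch

def stochastic_round_to_bf16(x: torch.Tensor) -> torch.Tensor:
    # Illustrative stochastic rounding: add uniform noise to the 16 low-order
    # bits that BF16 discards, then truncate them. A carry into the kept bits
    # rounds the magnitude up with probability proportional to the discarded
    # fraction. Edge cases (inf/NaN, values near the FP32 maximum) are ignored.
    assert x.dtype == torch.float32
    bits = x.contiguous().view(torch.int32)
    noise = torch.randint(0, 1 << 16, x.shape, dtype=torch.int32, device=x.device)
    truncated = (bits + noise) & -65536  # zero out the low 16 bits after the carry
    return truncated.view(torch.float32).to(torch.bfloat16)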

GPT-2 runs (nanoGPT_sr/, adapted from nanoGPT)

Dependencies:

  • pytorch
  • numpy
  • transformers for huggingface transformers
  • datasets for huggingface datasets
  • tiktoken for OpenAI's fast BPE code
  • wandb for optional logging
  • tqdm for progress bars
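These can typically be installed via pip (package names below are the standard PyPI names; adjust the PyTorch install to your CUDA/driver setup):

pip install torch numpy transformers datasets tiktoken wandb tqdm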

We use the AWS PyTorch 2.1 AMI.

Directory structure:

./config includes the configuration files that control the parameters in the code.

./data includes the data preparation code to download and tokenize the dataset.

./adamw_sr.py contains our BF16+SR optimizer with shared randomness.

./model.py includes a generic GPT-2 style implementation from nanoGPT.

./train.py is the main training script.

./configurator.py imports configuration variables.

Data prep:

Simply call

python data/openwebtext/prepare.py

Example call:

After adjusting ./config/train_gpt2.py so that

dtype = 'bfloat16' #'float32', 'bfloat16'
mixed_precision = False
stoc_rounding = True

are set accordingly, simply run

torchrun --standalone --nproc_per_node=8 train.py config/train_gpt2.py
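If the launch conventions from upstream nanoGPT are unchanged in nanoGPT_sr, a quick single-GPU (non-DDP) run can presumably be started without torchrun:

python train.py config/train_gpt2.py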

GPT-Neo runs (NeMo-GPT/, adapted from NVIDIA NeMo)

We use NeMo version 1.22. Note that the NeMo scripts use a Docker container, so there is no need to install separate dependencies.

Dataset preparation: Wikipedia data (~20GB) is used for training our GPT models (Instructions). The steps therein consist of downloading the data, extracting the raw text, downloading the HuggingFace GPT-2 tokenizer, and converting the training data into memory-map format. Once finished, you should have the following files:

1. gpt2-vocab.json 
2. gpt2-merges.txt
3. train_data.jsonl
4. hfbpe_gpt_training_data_text_document.bin 
5. hfbpe_gpt_training_data_text_document.idx 

Directory structure:

./NeMo-GPT contains all the NeMo repository files, including our optimizer in ./NeMo-GPT/nemo/core/optim/adamw_sr.py.

./model_configs includes the base configuration file.

./shell_scripts includes the .sh files to run the experiments and the run.slurm file for running the code on a SLURM-managed distributed cluster.

How to run:

After downloading and processing the data, edit the run.slurm file to point to the configuration (.sh file) you want to run, then call

sbatch run.slurm

The config file will pull a Docker image and install the NeMo environment automatically.

Note: at the beginning of each .sh file, change the hardcoded file locations to match your own directory layout.

# For example:
#Dataset:
DATA_PATH="./llm-training-nemo/examples_datasets"

#Path to NeMo-GPT submodule
NEMO_PATH="./llm-training-nemo/NeMo-GPT"

#Experiment logs: Specify expt_name and date-time
EXP_PATH="./llm-training-nemo/log_megatron/nemo_$(date "+%y-%m-%d-%H-%M-%S")-gpt-megatron"

#Checkpoint and training logs
CHKPT_PATH="./llm-training-nemo/nemo_experiments"

Example custom usage

The AdamW_SR optimizer is designed as a drop-in replacement for, e.g., torch.optim.AdamW, with the same usage: initialization, .step(), and .zero_grad(). One important caveat is that the optimizer states and model parameters must be in torch.bfloat16; otherwise, the optimizer will throw an error. Furthermore, some functionality of torch.optim.AdamW is not yet implemented, e.g. fused Adam and distributed multi-tensor Adam. For best results, use a 2-4x larger maximum learning rate than you would use for regular (mixed-precision/FP32) AdamW; the minimum learning rate can be kept as is.

import torch
import torch.nn as nn

from adamw_sr import AdamW_SR_BF16

# Define a simple model (any nn.Module works; this one is a placeholder)
class SimpleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(32, 1)

    def forward(self, x):
        return self.fc(x)

model = SimpleNet()

# Convert model to bf16; AdamW_SR_BF16 requires parameters in torch.bfloat16
model.to(torch.bfloat16)

# Define the optimizer; since the model params are in bf16, the optimizer states will also be
optimizer = AdamW_SR_BF16(model.parameters(), lr=0.001, weight_decay=0.01)

# Dummy data and loss for illustration
inputs = torch.randn(16, 32, dtype=torch.bfloat16)
targets = torch.randn(16, 1, dtype=torch.bfloat16)
criterion = nn.MSELoss()

# Training loop
num_epochs = 10
for epoch in range(num_epochs):
    # Forward pass
    outputs = model(inputs)

    # Calculate the loss
    loss = criterion(outputs, targets)

    # Backward pass
    loss.backward()

    # Update the weights
    optimizer.step()

    # Zero the gradients
    optimizer.zero_grad()

    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')
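To follow the learning-rate advice above, only the peak learning rate needs to be scaled. A short sketch continuing the example (the specific values and the cosine schedule are illustrative, not prescribed by the paper):

# Hypothetical values: scale only the peak LR by 2-4x, keep the floor unchanged
baseline_max_lr = 6e-4   # peak LR you would use with regular (mixed-precision/fp32) AdamW
min_lr = 6e-5            # minimum LR, kept as is
optimizer = AdamW_SR_BF16(model.parameters(), lr=3 * baseline_max_lr, weight_decay=0.01)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=min_lr)

scheduler.step() would then be called once per epoch in the training loop above.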
  • Other adamw-bf16 optimizers that may be useful: 1 2

Citing

If you use this work, please consider citing

@inproceedings{ozkara2025stochastic,
  title={Stochastic Rounding for {LLM} Training: Theory and Practice},
  author={Kaan Ozkara and Tao Yu and Youngsuk Park},
  booktitle={The 28th International Conference on Artificial Intelligence and Statistics},
  year={2025},
  url={https://openreview.net/forum?id=3j3NtXcc95}
}
