4 changes: 3 additions & 1 deletion nanochat/checkpoint_manager.py
@@ -12,6 +12,7 @@
from nanochat.gpt import GPT, GPTConfig
from nanochat.tokenizer import get_tokenizer
from nanochat.common import setup_default_logging
from nanochat.fpquant import add_qat

# Set up logging
setup_default_logging()
@@ -75,7 +76,8 @@ def build_model(checkpoint_dir, step, device, phase):
    # Load the model state
    model.to_empty(device=device)
    model.init_weights() # note: this is dumb, but we need to init the rotary embeddings. TODO: fix model re-init
    model.load_state_dict(model_data, strict=True, assign=True)
    model.load_state_dict(model_data, strict=False, assign=True)
    model = add_qat(model, phase != "eval")
    # Put the model in the right training phase / mode
    if phase == "eval":
        model.eval()
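A note on the `strict=False` change above: once QAT wrappers are in the picture, the saved state dict and a freshly built GPT module tree may not line up key-for-key, so the load is made non-strict and `add_qat` is applied afterwards. A minimal sketch of how the mismatched keys could be surfaced during loading; the result handling is an illustrative addition, not part of the PR:

```python
# Sketch only: with strict=False, load_state_dict returns the missing/unexpected
# keys, which is handy for checking that only QAT-related keys differ.
result = model.load_state_dict(model_data, strict=False, assign=True)
if result.missing_keys or result.unexpected_keys:
    print("missing keys:", result.missing_keys)
    print("unexpected keys:", result.unexpected_keys)
model = add_qat(model, phase != "eval")
```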
16 changes: 16 additions & 0 deletions nanochat/fpquant.py
@@ -0,0 +1,16 @@
from fp_quant import FPQuantDtype, FPQuantConfig, replace_quantize_with_fp_quant_linear


def add_qat(model, store_master_weights):
    model = replace_quantize_with_fp_quant_linear(
        model,
        fp_quant_linear_config=FPQuantConfig(
            forward_dtype=FPQuantDtype.MXFP4,
            forward_method="abs_max",
            hadamard_group_size=128,
            backward_dtype=FPQuantDtype.MXFP8,
            store_master_weights=store_master_weights,
            modules_to_not_convert=["lm_head"],
        ),
    )
    return model
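For reference, the two call sites this PR adds, condensed from the checkpoint_manager.py hunk above and the base_train.py hunk later in this diff. The boolean controls whether high-precision master weights are stored, which matters only when the model will be trained:

```python
# In scripts/base_train.py, right after weight init for a fresh pretraining run:
model = GPT(model_config)
model.to_empty(device="cuda")
model.init_weights()
model = add_qat(model, True)                 # training -> store master weights

# In nanochat/checkpoint_manager.py, when rebuilding a model from a checkpoint:
model.load_state_dict(model_data, strict=False, assign=True)
model = add_qat(model, phase != "eval")      # eval -> master weights not needed
```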
6 changes: 3 additions & 3 deletions nanochat/gpt.py
@@ -35,7 +35,7 @@ class GPTConfig:

def norm(x):
    # Purely functional rmsnorm with no learnable params
    return F.rms_norm(x, (x.size(-1),))
    return F.rms_norm(x, (x.size(-1),)).to(torch.bfloat16)


def apply_rotary_emb(x, cos, sin):
@@ -121,7 +121,7 @@ def forward(self, x, cos_sin, kv_cache):
        y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)

        # Re-assemble the heads side by side and project back to residual stream
        y = y.transpose(1, 2).contiguous().view(B, T, -1)
        y = y.transpose(1, 2).contiguous().view(B, T, -1).to(torch.bfloat16)
        y = self.c_proj(y)
        return y

@@ -134,7 +134,7 @@ def __init__(self, config):

    def forward(self, x):
        x = self.c_fc(x)
        x = F.relu(x).square()
        x = F.relu(x).square().to(torch.bfloat16)
[PR author comment on the line above: "Hopefully fuse cast to BF16 with act fn to not have to cast explicitly for quantized GEMM kernels again."]

        x = self.c_proj(x)
        return x

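The three gpt.py edits above follow one pattern: explicitly cast the activation that feeds a now-quantized linear layer to BF16, the hope (per the review comment) being that torch.compile, which this PR also switches to fullgraph=True in the training scripts, fuses the cast into the preceding op instead of paying for a separate cast before the quantized GEMM. A condensed recap, sketching only the changed expressions; the class scaffolding is as in nanochat's gpt.py and is elided here:

```python
import torch
import torch.nn.functional as F

def norm(x):
    # rmsnorm output cast so the next quantized projection sees BF16
    return F.rms_norm(x, (x.size(-1),)).to(torch.bfloat16)

# Attention: cast the re-assembled heads before the output projection
#   y = y.transpose(1, 2).contiguous().view(B, T, -1).to(torch.bfloat16)
#   y = self.c_proj(y)

# MLP: cast after the squared-ReLU activation, before c_proj
#   x = F.relu(x).square().to(torch.bfloat16)
#   x = self.c_proj(x)
```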
8 changes: 8 additions & 0 deletions pyproject.toml
@@ -16,18 +16,26 @@ dependencies = [
"torch>=2.8.0",
"uvicorn>=0.36.0",
"wandb>=0.21.3",
"fp_quant>=0.3.2",
"qutlass>=0.2.0",
]

[build-system]
requires = ["maturin>=1.7,<2.0"]
build-backend = "maturin"

# build qutlass with system cuda
[tool.uv]
no-build-isolation-package = ["qutlass"]

# target torch to cuda 12.8
[tool.uv.sources]
torch = [
{ index = "pytorch-cu128" },
]

qutlass = { git = "https://github.com/IST-DASLab/qutlass", tag = "v0.2.0" }

[[tool.uv.index]]
name = "pytorch-cu128"
url = "https://download.pytorch.org/whl/cu128"
14 changes: 7 additions & 7 deletions run1000.sh
@@ -74,18 +74,18 @@ python -m scripts.tok_eval
# which would decrease model performance. Possibly 2, 3 or so epochs is ~ok, but certainly not ideal and at 10+ epochs we'd
# start to overfit hard.
# 5) That's it, everything else (e.g. the learning rates) is adjusted automatically by the training script.
torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- --depth=32 --device_batch_size=8
torchrun --standalone --nproc_per_node=8 -m scripts.base_loss
torchrun --standalone --nproc_per_node=8 -m scripts.base_eval
torchrun --nproc_per_node=8 --rdzv_backend=static --rdzv_id=speedrun -m scripts.base_train -- --depth=32 --device_batch_size=32 --run=$WANDB_RUN
torchrun --nproc_per_node=8 --rdzv_backend=static --rdzv_id=speedrun -m scripts.base_loss
torchrun --nproc_per_node=8 --rdzv_backend=static --rdzv_id=speedrun -m scripts.base_eval

# midtrain
# NOTE: ensure that we use the same device_batch_size here as the base training script.
torchrun --standalone --nproc_per_node=8 -m scripts.mid_train -- --device_batch_size=8 --run=$WANDB_RUN
torchrun --standalone --nproc_per_node=8 -m scripts.chat_eval -- -i mid
torchrun --nproc_per_node=8 --rdzv_backend=static --rdzv_id=speedrun -m scripts.mid_train -- --device_batch_size=32 --run=$WANDB_RUN
torchrun --nproc_per_node=8 --rdzv_backend=static --rdzv_id=speedrun -m scripts.chat_eval -- -i mid

# sft
torchrun --standalone --nproc_per_node=8 -m scripts.chat_sft -- --run=$WANDB_RUN
torchrun --standalone --nproc_per_node=8 -m scripts.chat_eval -- -i sft
torchrun --nproc_per_node=8 --rdzv_backend=static --rdzv_id=speedrun -m scripts.chat_sft -- --run=$WANDB_RUN
torchrun --nproc_per_node=8 --rdzv_backend=static --rdzv_id=speedrun -m scripts.chat_eval -- -i sft

# generate final report
python -m nanochat.report generate
8 changes: 5 additions & 3 deletions scripts/base_train.py
@@ -21,6 +21,7 @@
from nanochat.checkpoint_manager import save_checkpoint
from nanochat.loss_eval import evaluate_bpb
from nanochat.engine import Engine
from nanochat.fpquant import add_qat
from scripts.base_eval import evaluate_model
print_banner()

@@ -98,8 +99,9 @@
model = GPT(model_config)
model.to_empty(device="cuda")
model.init_weights()
model = add_qat(model, True)
orig_model = model # original, uncompiled model, for saving raw model state_dict
model = torch.compile(model, dynamic=False) # TODO: dynamic True/False think through
model = torch.compile(model, dynamic=False, fullgraph=True) # TODO: dynamic True/False think through
num_params = sum(p.numel() for p in model.parameters())
print0(f"Number of parameters: {num_params:,}")
num_flops_per_token = model.estimate_flops()
@@ -219,7 +221,7 @@ def get_muon_momentum(it):
"My favorite color is",
"If 5*x + 3 = 13, then x is",
]
engine = Engine(model, tokenizer)
engine = Engine(orig_model, tokenizer)
for prompt in prompts:
tokens = tokenizer(prompt, prepend="<|bos|>")
with autocast_ctx:
@@ -284,7 +286,7 @@ def get_muon_momentum(it):
smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss.item() # EMA the training loss
debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1)) # debias the EMA
pct_done = 100 * step / num_iterations
tok_per_sec = int(world_tokens_per_fwdbwd / dt)
tok_per_sec = int(total_batch_size / dt)
flops_per_sec = num_flops_per_token * total_batch_size / dt
promised_flops_per_sec_h100 = 989e12 * ddp_world_size # bfloat16 H100 SXM and without 2:4 sparsity
mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in %
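The tok_per_sec change makes the throughput line consistent with the FLOPs/MFU lines just below it: both are now based on total_batch_size, the number of tokens consumed per optimizer step across all ranks and gradient-accumulation passes, rather than the tokens of a single forward/backward. A small sketch of the bookkeeping; grad_accum_steps, dt, and the per-token FLOP count are illustrative assumptions, not values from the script:

```python
device_batch_size = 32            # sequences per rank per forward/backward
max_seq_len = 2048
ddp_world_size = 8
grad_accum_steps = 2              # hypothetical

tokens_per_fwdbwd = device_batch_size * max_seq_len               # one rank, one fwd/bwd
world_tokens_per_fwdbwd = tokens_per_fwdbwd * ddp_world_size      # all ranks, one fwd/bwd
total_batch_size = world_tokens_per_fwdbwd * grad_accum_steps     # tokens per optimizer step

dt = 1.5                          # seconds per optimizer step (hypothetical)
num_flops_per_token = 3.4e9       # stand-in for model.estimate_flops()

tok_per_sec = int(total_batch_size / dt)
flops_per_sec = num_flops_per_token * total_batch_size / dt
promised_flops_per_sec_h100 = 989e12 * ddp_world_size  # BF16 H100 SXM peak, no 2:4 sparsity
mfu = 100 * flops_per_sec / promised_flops_per_sec_h100
print(f"{tok_per_sec:,} tok/s, MFU {mfu:.1f}%")
```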
2 changes: 1 addition & 1 deletion scripts/chat_sft.py
@@ -69,7 +69,7 @@
# Load the model and tokenizer
model, tokenizer, meta = load_model(source, device, phase="train", model_tag=model_tag, step=step)
orig_model = model # original, uncompiled model
# model = torch.compile(model, dynamic=True) # doesn't work super well because of variable lengths of inputs
model = torch.compile(model, dynamic=True) # doesn't work super well because of variable lengths of inputs
engine = Engine(model, tokenizer) # will be used for inline model evaluation only

# -----------------------------------------------------------------------------
6 changes: 3 additions & 3 deletions scripts/mid_train.py
@@ -34,7 +34,7 @@
step = None # step to load the model from (base model or midtrained model)
dtype = "bfloat16"
max_seq_len = 2048
device_batch_size = 32
device_batch_size = 64
unembedding_lr = 0.004
embedding_lr = 0.2
matrix_lr = 0.02
@@ -65,7 +65,7 @@
if pretrain_batch_size is not None and device_batch_size > pretrain_batch_size:
    print0(f"FOOTGUN WARNING: base model training used device_batch_size {pretrain_batch_size}, did you pass in a good --device_batch_size to this script?")
orig_model = model
model = torch.compile(model, dynamic=False)
model = torch.compile(model, dynamic=False, fullgraph=True)
depth = model.config.n_layer
num_flops_per_token = model.estimate_flops()
tokens_per_fwdbwd = device_batch_size * max_seq_len # tokens per iteration for a single rank
@@ -248,7 +248,7 @@ def get_muon_momentum(it):
smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss.item() # EMA the training loss
debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1)) # debias the EMA
pct_done = 100 * progress
tok_per_sec = int(world_tokens_per_fwdbwd / dt)
tok_per_sec = int(total_batch_size / dt)
flops_per_sec = num_flops_per_token * total_batch_size / dt
promised_flops_per_sec_h100 = 989e12 * ddp_world_size # bfloat16 H100 SXM and without 2:4 sparsity
mfu = 100 * flops_per_sec / promised_flops_per_sec_h100 # in %
18 changes: 9 additions & 9 deletions speedrun.sh
@@ -92,25 +92,25 @@ echo "Waiting for dataset download to complete..."
wait $DATASET_DOWNLOAD_PID

# pretrain the d20 model
torchrun --standalone --nproc_per_node=8 -m scripts.base_train -- --depth=20 --run=$WANDB_RUN
torchrun --nproc_per_node=8 --rdzv_backend=static --rdzv_id=speedrun -m scripts.base_train -- --depth=20 --run=$WANDB_RUN
# evaluate the model on a larger chunk of train/val data and draw some samples
torchrun --standalone --nproc_per_node=8 -m scripts.base_loss
torchrun --nproc_per_node=8 --rdzv_backend=static --rdzv_id=speedrun -m scripts.base_loss
# evaluate the model on CORE tasks
torchrun --standalone --nproc_per_node=8 -m scripts.base_eval
torchrun --nproc_per_node=8 --rdzv_backend=static --rdzv_id=speedrun -m scripts.base_eval

# -----------------------------------------------------------------------------
# Midtraining (teach the model conversation special tokens, tool use, multiple choice)

# run midtraining and eval the model
torchrun --standalone --nproc_per_node=8 -m scripts.mid_train -- --run=$WANDB_RUN
torchrun --standalone --nproc_per_node=8 -m scripts.chat_eval -- -i mid
torchrun --nproc_per_node=8 --rdzv_backend=static --rdzv_id=speedrun -m scripts.mid_train -- --run=$WANDB_RUN
torchrun --nproc_per_node=8 --rdzv_backend=static --rdzv_id=speedrun -m scripts.chat_eval -- -i mid

# -----------------------------------------------------------------------------
# Supervised Finetuning (domain adaptation to each sequence all by itself per row)

# train sft and re-eval right away (should see a small bump)
torchrun --standalone --nproc_per_node=8 -m scripts.chat_sft -- --run=$WANDB_RUN
torchrun --standalone --nproc_per_node=8 -m scripts.chat_eval -- -i sft
torchrun --nproc_per_node=8 --rdzv_backend=static --rdzv_id=speedrun -m scripts.chat_sft -- --run=$WANDB_RUN
torchrun --nproc_per_node=8 --rdzv_backend=static --rdzv_id=speedrun -m scripts.chat_eval -- -i sft

# chat with the model over CLI! Leave out the -p to chat interactively
# python -m scripts.chat_cli -p "Why is the sky blue?"
@@ -123,9 +123,9 @@ torchrun --standalone --nproc_per_node=8 -m scripts.chat_eval -- -i sft
# (optional)

# run reinforcement learning
# torchrun --standalone --nproc_per_node=8 -m scripts.chat_rl -- --run=$WANDB_RUN
# torchrun --nproc_per_node=8 --rdzv_backend=static --rdzv_id=speedrun -m scripts.chat_rl -- --run=$WANDB_RUN
# eval the RL model only on GSM8K
# torchrun --standalone --nproc_per_node=8 -m scripts.chat_eval -- -i rl -a GSM8K
# torchrun --nproc_per_node=8 --rdzv_backend=static --rdzv_id=speedrun -m scripts.chat_eval -- -i rl -a GSM8K

# -----------------------------------------------------------------------------
# Generate the full report by putting together all the sections