Skip to content
Merged
4 changes: 3 additions & 1 deletion src/axolotl/core/builders/causal.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,9 @@ def get_callbacks(self):
if self.cfg.include_tkps:
callbacks.append(
TokensPerSecondCallback(
self.cfg.tensor_parallel_size, self.cfg.context_parallel_size
self.cfg.tensor_parallel_size,
self.cfg.context_parallel_size,
resume_from_checkpoint=self.cfg.resume_from_checkpoint,
)
)
return callbacks
Expand Down
58 changes: 44 additions & 14 deletions src/axolotl/core/trainers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import json
import math
import os
from collections import defaultdict
Expand Down Expand Up @@ -50,6 +51,8 @@

LOG = get_logger(__name__)

# File name, inside each checkpoint directory, of the JSON file holding the
# cumulative token counters. It must match the name the tokens-per-second
# callback looks for when resuming, otherwise the counters are never restored.
TOKENS_STATE_FILE = "tokens_state.json"

REDUCTION_FNS = {
"mean": torch.mean,
"min": torch.min,
Expand Down Expand Up @@ -349,24 +352,34 @@ def compute_loss(
# return (loss, outputs) if return_outputs else loss

# track number of tokens for tokens per second calculation
if self.args.include_tkps:
if self.args.include_tkps and model.training:
inputs_key = "labels" if "labels" in inputs else "input_ids"
num_tokens = (inputs[inputs_key] != -100).sum()
trainable_tokens = (inputs[inputs_key] != -100).sum()
total_tokens = inputs[inputs_key].numel()

if is_distributed():
torch.distributed.all_reduce(
num_tokens, op=torch.distributed.ReduceOp.SUM
trainable_tokens, op=torch.distributed.ReduceOp.SUM
)
if hasattr(self.state, "num_tokens"):
self.state.num_tokens = (
self.state.num_tokens + (inputs[inputs_key] != -100).sum().cpu()
torch.distributed.all_reduce(
total_tokens, op=torch.distributed.ReduceOp.SUM
)
else:
self.state.num_tokens = (inputs[inputs_key] != -100).sum().cpu()

if hasattr(self.state, "total_tokens"):
self.state.total_tokens += num_tokens
else:
self.state.total_tokens = num_tokens
if not hasattr(self.state, "tokens"):
self.state.tokens = {
"trainable": torch.zeros(1),
"total": torch.zeros(1),
}

# trainable tokens for throughput and total token slots for summaries
self.state.tokens["trainable"] = (
self.state.tokens["trainable"] + trainable_tokens.detach().cpu()
)
self.state.tokens["total"] = (
self.state.tokens["total"] + torch.as_tensor(total_tokens).cpu()
)
# Store per-step trainable tokens for throughput calculation
self.state.tokens["trainable_tokens"] = trainable_tokens.detach().cpu()

if self.args.orpo_alpha:
return self.orpo_compute_loss(
Expand Down Expand Up @@ -637,10 +650,14 @@ def log(self, logs: dict[str, float], start_time: float | None = None) -> None:
except (ValueError, TypeError, FileNotFoundError):
pass

if self.args.include_tkps and train_eval == "train":
if (
self.args.include_tkps
and train_eval == "train"
and hasattr(self.state, "tokens")
):
# each rank will log its own tokens per second
# for logging_steps > 1 we obtain a moving average of this metric
logs["tokens_per_second_per_gpu"] = round(
logs["tokens/train_per_sec_per_gpu"] = round(
self.state.last_tokens_per_second.item() / self.args.logging_steps, 2
)
if (
Expand Down Expand Up @@ -682,6 +699,19 @@ def _save_checkpoint(self, model, trial, **kwargs):
run_dir = self._get_output_dir(trial=trial)
output_dir = os.path.join(run_dir, checkpoint_folder)
os.makedirs(output_dir, exist_ok=True)

# Save total_tokens state if tracking is enabled
if self.args.include_tkps and hasattr(self.state, "tokens"):
tokens_state = {
"total": int(torch.as_tensor(self.state.tokens.get("total", 0)).item()),
"trainable": int(
torch.as_tensor(self.state.tokens.get("trainable", 0)).item()
),
}
tokens_state_path = os.path.join(output_dir, TOKENS_STATE_FILE)
with open(tokens_state_path, "w", encoding="utf-8") as f:
json.dump(tokens_state, f)

return super()._save_checkpoint(model, trial, **kwargs)

# TODO(wing): remove once https://github.com/huggingface/transformers/pull/39866/files is merged
Expand Down
53 changes: 49 additions & 4 deletions src/axolotl/utils/callbacks/tokens_per_second.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""A callback for calculating tokens per second during training."""

import json
import os
import time

import torch
Expand All @@ -10,29 +12,61 @@
TrainingArguments,
)

from axolotl.utils.logging import get_logger

LOG = get_logger(__name__)

# Name of the JSON file, located inside a checkpoint directory, from which the
# callback reads the saved "total" and "trainable" token counters on resume.
TOKENS_STATE_FILE = "tokens_state.json"


class TokensPerSecondCallback(TrainerCallback):
"""
A callback to measure and log tokens per second during training.
Also handles saving/restoring total_tokens state across checkpoint resumes.
"""

def __init__(self, tensor_parallel_size, context_parallel_size):
def __init__(
self, tensor_parallel_size, context_parallel_size, resume_from_checkpoint=None
):
super().__init__()
self.step_time = 0.0
self.start_time = 0.0
self.non_data_parallel_size = 1
self.resume_from_checkpoint = resume_from_checkpoint
if tensor_parallel_size is not None:
self.non_data_parallel_size *= tensor_parallel_size
if context_parallel_size is not None:
self.non_data_parallel_size *= context_parallel_size

def on_train_begin(
    self,
    args: TrainingArguments,
    state: TrainerState,
    control: TrainerControl,
    **kwargs,
):  # pylint: disable=unused-argument
    """Restore the cumulative token counters when resuming from a checkpoint.

    Nothing happens unless ``resume_from_checkpoint`` is a string path and
    that directory contains ``TOKENS_STATE_FILE``. When the file is present,
    its ``"total"`` and ``"trainable"`` values (defaulting to 0) are loaded
    into ``state.tokens`` as tensors.

    The counters only feed logging, so a state file that cannot be read or
    parsed is reported with a warning instead of aborting the resume.
    """
    if not isinstance(self.resume_from_checkpoint, str):
        return

    tokens_state_path = os.path.join(self.resume_from_checkpoint, TOKENS_STATE_FILE)
    if not os.path.isfile(tokens_state_path):
        return

    try:
        with open(tokens_state_path, "r", encoding="utf-8") as f:
            tokens_state = json.load(f)
    except (OSError, ValueError) as err:
        # json.JSONDecodeError is a ValueError subclass; OSError covers
        # permission problems and files removed between the check and the open.
        LOG.warning(f"Could not read {tokens_state_path}: {err}")
        return

    if not isinstance(tokens_state, dict):
        LOG.warning(f"Ignoring {tokens_state_path}: expected a JSON object")
        return

    state.tokens = {
        "total": torch.tensor(tokens_state.get("total", 0)),
        "trainable": torch.tensor(tokens_state.get("trainable", 0)),
    }
    LOG.info(f"Restored total_tokens: {state.tokens['total']}")

def on_step_begin(
    self,
    args: TrainingArguments,
    state: TrainerState,
    control: TrainerControl,
    **kwargs,
):  # pylint: disable=unused-argument
    """Start the step timer and make sure the token counters exist.

    Creates ``state.tokens`` with zeroed ``"trainable"`` and ``"total"``
    counters if it is missing, records the step start time, and resets
    ``state.last_tokens_per_second`` to zero.
    """
    if not hasattr(state, "tokens"):
        # Use integer accumulators: the float32 default of ``torch.zeros``
        # cannot represent every integer above 2**24, so cumulative token
        # counts would silently lose precision on long runs.
        state.tokens = {
            "trainable": torch.zeros(1, dtype=torch.int64),
            "total": torch.zeros(1, dtype=torch.int64),
        }
    self.start_time = time.perf_counter()
    state.last_tokens_per_second = torch.zeros(1)

Expand All @@ -43,9 +77,10 @@ def on_step_end(
control: TrainerControl,
**kwargs,
): # pylint: disable=unused-argument
if hasattr(state, "num_tokens"):
tokens = getattr(state, "tokens", None)
if tokens and "trainable_tokens" in tokens:
step_time = time.perf_counter() - self.start_time
num_tokens_per_device = state.num_tokens.clone()
num_tokens_per_device = tokens["trainable_tokens"].clone()
# non data parallel groups have duplicated tokens, so we avoid double-counting
num_tokens_per_device = num_tokens_per_device / self.non_data_parallel_size
state.last_tokens_per_second = num_tokens_per_device / step_time
Expand All @@ -60,5 +95,15 @@ def on_log(
): # pylint: disable=unused-argument
# after logging, clear the running metrics
if hasattr(state, "last_tokens_per_second"):
logs["tokens/train_per_sec_per_gpu"] = state.last_tokens_per_second.item()
state.last_tokens_per_second.zero_()
state.num_tokens = torch.zeros(1)
tokens = getattr(state, "tokens", None)
# Clear per-step tokens after logging
if tokens and "trainable_tokens" in tokens:
tokens["trainable_tokens"] = torch.zeros_like(tokens["trainable_tokens"])

if tokens and "total" in tokens:
logs["tokens/total"] = tokens["total"].item()

if tokens and "trainable" in tokens:
logs["tokens/trainable"] = tokens["trainable"].item()
Comment on lines +105 to +109

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this a duplicate of the logging in base.py L651-652?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. Is it redundant?

33 changes: 32 additions & 1 deletion tests/e2e/patched/test_resume.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@
E2E tests for resuming training
"""

import os
import re
import subprocess

from transformers.utils import is_torch_bf16_gpu_available

from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.callbacks.tokens_per_second import TOKENS_STATE_FILE
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault

Expand Down Expand Up @@ -58,6 +60,7 @@ def test_resume_lora_packed(self, temp_dir):
"use_tensorboard": True,
"save_safetensors": True,
"save_first_step": False,
"include_tkps": True,
}
)
if is_torch_bf16_gpu_available():
Expand All @@ -68,16 +71,44 @@ def test_resume_lora_packed(self, temp_dir):
normalize_config(cfg)
dataset_meta = load_datasets(cfg=cfg)

initial_total_num_tokens = cfg.total_num_tokens
assert initial_total_num_tokens is not None, (
"total_num_tokens should be calculated during load_datasets"
)

train(cfg=cfg, dataset_meta=dataset_meta)

checkpoint_path = f"{temp_dir}/checkpoint-9"
tokens_state_path = os.path.join(checkpoint_path, TOKENS_STATE_FILE)
assert os.path.isfile(tokens_state_path), (
f"{TOKENS_STATE_FILE} should exist in checkpoint at {tokens_state_path}"
)

resume_cfg = cfg | DictDefault(
{
"resume_from_checkpoint": f"{temp_dir}/checkpoint-9/",
}
)
normalize_config(resume_cfg)

train(cfg=resume_cfg, dataset_meta=dataset_meta)
assert resume_cfg.total_num_tokens == initial_total_num_tokens, (
f"total_num_tokens should be preserved on resume. "
f"Expected {initial_total_num_tokens}, got {resume_cfg.total_num_tokens}"
)

resume_dataset_meta = load_datasets(cfg=resume_cfg)

assert resume_cfg.total_num_tokens == initial_total_num_tokens, (
f"total_num_tokens should not be recalculated when resuming. "
f"Expected {initial_total_num_tokens}, got {resume_cfg.total_num_tokens}"
)

train(cfg=resume_cfg, dataset_meta=resume_dataset_meta)

assert resume_cfg.total_num_tokens == initial_total_num_tokens, (
f"total_num_tokens should remain unchanged after resume training. "
f"Expected {initial_total_num_tokens}, got {resume_cfg.total_num_tokens}"
)
check_model_output_exists(temp_dir, cfg)

tb_log_path_1 = most_recent_subdir(temp_dir + "/runs")
Expand Down
Loading