From 2a558f68059f7844174b439374aada732eb78704 Mon Sep 17 00:00:00 2001 From: Ved Date: Tue, 2 Dec 2025 11:00:48 +0530 Subject: [PATCH 01/10] compute loss only if training --- src/axolotl/core/trainers/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py index 7896c60889..8e848156a4 100644 --- a/src/axolotl/core/trainers/base.py +++ b/src/axolotl/core/trainers/base.py @@ -348,7 +348,7 @@ def compute_loss( # return (loss, outputs) if return_outputs else loss # track number of tokens for tokens per second calculation - if self.args.include_tkps: + if self.args.include_tkps and model.training: inputs_key = "labels" if "labels" in inputs else "input_ids" num_tokens = (inputs[inputs_key] != -100).sum() if is_distributed(): From 1a161a522f707cd62a9cac0e5acff1ce0aa52c43 Mon Sep 17 00:00:00 2001 From: Ved Date: Wed, 3 Dec 2025 13:20:49 +0530 Subject: [PATCH 02/10] save total_tokens for checkpoint --- src/axolotl/core/builders/causal.py | 4 ++- src/axolotl/core/trainers/base.py | 26 +++++++++++++-- .../utils/callbacks/tokens_per_second.py | 33 ++++++++++++++++++- 3 files changed, 59 insertions(+), 4 deletions(-) diff --git a/src/axolotl/core/builders/causal.py b/src/axolotl/core/builders/causal.py index 7a06431dc0..cda98087f0 100644 --- a/src/axolotl/core/builders/causal.py +++ b/src/axolotl/core/builders/causal.py @@ -72,7 +72,9 @@ def get_callbacks(self): if self.cfg.include_tkps: callbacks.append( TokensPerSecondCallback( - self.cfg.tensor_parallel_size, self.cfg.context_parallel_size + self.cfg.tensor_parallel_size, + self.cfg.context_parallel_size, + resume_from_checkpoint=self.cfg.resume_from_checkpoint, ) ) return callbacks diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py index 8e848156a4..d45201ec27 100644 --- a/src/axolotl/core/trainers/base.py +++ b/src/axolotl/core/trainers/base.py @@ -2,6 +2,7 @@ from __future__ import annotations +import 
json import os from collections import defaultdict from functools import partial, wraps @@ -49,6 +50,8 @@ LOG = get_logger(__name__) +TOKENS_STATE_FILE = "tokens_state.json" + REDUCTION_FNS = { "mean": torch.mean, "min": torch.min, @@ -363,9 +366,9 @@ def compute_loss( self.state.num_tokens = (inputs[inputs_key] != -100).sum().cpu() if hasattr(self.state, "total_tokens"): - self.state.total_tokens += num_tokens + self.state.total_tokens += num_tokens.cpu() else: - self.state.total_tokens = num_tokens + self.state.total_tokens = num_tokens.cpu() if self.args.orpo_alpha: return self.orpo_compute_loss( @@ -666,6 +669,25 @@ def _save_checkpoint(self, model, trial, **kwargs): run_dir = self._get_output_dir(trial=trial) output_dir = os.path.join(run_dir, checkpoint_folder) os.makedirs(output_dir, exist_ok=True) + + # Save total_tokens state if tracking is enabled + if self.args.include_tkps and hasattr(self.state, "total_tokens"): + tokens_state = { + "total_tokens": ( + int(self.state.total_tokens.item()) + if hasattr(self.state.total_tokens, "item") + else int(self.state.total_tokens) + ), + "num_tokens": ( + int(self.state.num_tokens.item()) + if hasattr(self.state.num_tokens, "item") + else int(self.state.num_tokens) + ), + } + tokens_state_path = os.path.join(output_dir, TOKENS_STATE_FILE) + with open(tokens_state_path, "w", encoding="utf-8") as f: + json.dump(tokens_state, f) + return super()._save_checkpoint(model, trial, **kwargs) # TODO(wing): remove once https://github.com/huggingface/transformers/pull/39866/files is merged diff --git a/src/axolotl/utils/callbacks/tokens_per_second.py b/src/axolotl/utils/callbacks/tokens_per_second.py index ead1292400..f602341a79 100644 --- a/src/axolotl/utils/callbacks/tokens_per_second.py +++ b/src/axolotl/utils/callbacks/tokens_per_second.py @@ -1,5 +1,7 @@ """A callback for calculating tokens per second during training.""" +import json +import os import time import torch @@ -10,22 +12,51 @@ TrainingArguments, ) +from 
axolotl.utils.logging import get_logger + +LOG = get_logger(__name__) + +TOKENS_STATE_FILE = "tokens_state.json" + class TokensPerSecondCallback(TrainerCallback): """ A callback to measure and log tokens per second during training. + Also handles saving/restoring total_tokens state across checkpoint resumes. """ - def __init__(self, tensor_parallel_size, context_parallel_size): + def __init__( + self, tensor_parallel_size, context_parallel_size, resume_from_checkpoint=None + ): super().__init__() self.step_time = 0.0 self.start_time = 0.0 self.non_data_parallel_size = 1 + self.resume_from_checkpoint = resume_from_checkpoint if tensor_parallel_size is not None: self.non_data_parallel_size *= tensor_parallel_size if context_parallel_size is not None: self.non_data_parallel_size *= context_parallel_size + def on_train_begin( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs, + ): # pylint: disable=unused-argument + """Restore total_tokens state when resuming from checkpoint.""" + if self.resume_from_checkpoint: + tokens_state_path = os.path.join( + self.resume_from_checkpoint, TOKENS_STATE_FILE + ) + if os.path.isfile(tokens_state_path): + with open(tokens_state_path, "r", encoding="utf-8") as f: + tokens_state = json.load(f) + state.total_tokens = torch.tensor(tokens_state.get("total_tokens", 0)) + state.num_tokens = torch.tensor(tokens_state.get("num_tokens", 0)) + LOG.info(f"Restored total_tokens: {state.total_tokens}") + def on_step_begin( self, args: TrainingArguments, From b62171b5dac72f4ab39f823c1b1bfda16cc4199c Mon Sep 17 00:00:00 2001 From: Ved Date: Wed, 3 Dec 2025 13:29:00 +0530 Subject: [PATCH 03/10] check if string --- .../utils/callbacks/tokens_per_second.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/src/axolotl/utils/callbacks/tokens_per_second.py b/src/axolotl/utils/callbacks/tokens_per_second.py index f602341a79..3be0deed01 100644 --- 
a/src/axolotl/utils/callbacks/tokens_per_second.py +++ b/src/axolotl/utils/callbacks/tokens_per_second.py @@ -46,16 +46,15 @@ def on_train_begin( **kwargs, ): # pylint: disable=unused-argument """Restore total_tokens state when resuming from checkpoint.""" - if self.resume_from_checkpoint: - tokens_state_path = os.path.join( - self.resume_from_checkpoint, TOKENS_STATE_FILE - ) - if os.path.isfile(tokens_state_path): - with open(tokens_state_path, "r", encoding="utf-8") as f: - tokens_state = json.load(f) - state.total_tokens = torch.tensor(tokens_state.get("total_tokens", 0)) - state.num_tokens = torch.tensor(tokens_state.get("num_tokens", 0)) - LOG.info(f"Restored total_tokens: {state.total_tokens}") + if not isinstance(self.resume_from_checkpoint, str): + return + tokens_state_path = os.path.join(self.resume_from_checkpoint, TOKENS_STATE_FILE) + if os.path.isfile(tokens_state_path): + with open(tokens_state_path, "r", encoding="utf-8") as f: + tokens_state = json.load(f) + state.total_tokens = torch.tensor(tokens_state.get("total_tokens", 0)) + state.num_tokens = torch.tensor(tokens_state.get("num_tokens", 0)) + LOG.info(f"Restored total_tokens: {state.total_tokens}") def on_step_begin( self, From 8f9c8ddfa24aa75902de777a3b01c9e046023b70 Mon Sep 17 00:00:00 2001 From: Ved Date: Tue, 9 Dec 2025 18:24:47 +0530 Subject: [PATCH 04/10] refactor total_tokens/ num_tokens --- src/axolotl/core/trainers/base.py | 52 ++++++++++--------- .../utils/callbacks/tokens_per_second.py | 21 +++++--- 2 files changed, 43 insertions(+), 30 deletions(-) diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py index d45201ec27..67b642f2b9 100644 --- a/src/axolotl/core/trainers/base.py +++ b/src/axolotl/core/trainers/base.py @@ -353,22 +353,28 @@ def compute_loss( # track number of tokens for tokens per second calculation if self.args.include_tkps and model.training: inputs_key = "labels" if "labels" in inputs else "input_ids" - num_tokens = (inputs[inputs_key] 
!= -100).sum() + trainable_tokens = (inputs[inputs_key] != -100).sum() + total_tokens = inputs[inputs_key].numel() + if is_distributed(): torch.distributed.all_reduce( - num_tokens, op=torch.distributed.ReduceOp.SUM + trainable_tokens, op=torch.distributed.ReduceOp.SUM ) - if hasattr(self.state, "num_tokens"): - self.state.num_tokens = ( - self.state.num_tokens + (inputs[inputs_key] != -100).sum().cpu() + torch.distributed.all_reduce( + total_tokens, op=torch.distributed.ReduceOp.SUM ) - else: - self.state.num_tokens = (inputs[inputs_key] != -100).sum().cpu() - if hasattr(self.state, "total_tokens"): - self.state.total_tokens += num_tokens.cpu() - else: - self.state.total_tokens = num_tokens.cpu() + if not hasattr(self.state, "tokens"): + self.state.tokens = { + "trainable": torch.zeros(1), + "total": torch.zeros(1), + } + + # trainable tokens for throughput and total token slots for summaries + self.state.tokens["trainable"] = trainable_tokens.detach().cpu() + self.state.tokens["total"] = ( + self.state.tokens["total"] + torch.as_tensor(total_tokens).cpu() + ) if self.args.orpo_alpha: return self.orpo_compute_loss( @@ -628,13 +634,17 @@ def log(self, logs: dict[str, float], start_time: float | None = None) -> None: except (ValueError, TypeError, FileNotFoundError): pass - if self.args.include_tkps and train_eval == "train": + if ( + self.args.include_tkps + and train_eval == "train" + and hasattr(self.state, "tokens") + ): # each rank will log its own tokens per second # for logging_steps > 1 we obtain a moving average of this metric - logs["tokens_per_second_per_gpu"] = round( + logs["tokens/trainable_per_second_per_gpu"] = round( self.state.last_tokens_per_second.item() / self.args.logging_steps, 2 ) - logs["total_tokens"] = int(self.state.total_tokens.item()) + logs["tokens/total"] = int(self.state.tokens["total"].item()) del self._stored_metrics[train_eval] @@ -671,17 +681,11 @@ def _save_checkpoint(self, model, trial, **kwargs): os.makedirs(output_dir, 
exist_ok=True) # Save total_tokens state if tracking is enabled - if self.args.include_tkps and hasattr(self.state, "total_tokens"): + if self.args.include_tkps and hasattr(self.state, "tokens"): tokens_state = { - "total_tokens": ( - int(self.state.total_tokens.item()) - if hasattr(self.state.total_tokens, "item") - else int(self.state.total_tokens) - ), - "num_tokens": ( - int(self.state.num_tokens.item()) - if hasattr(self.state.num_tokens, "item") - else int(self.state.num_tokens) + "total": int(torch.as_tensor(self.state.tokens.get("total", 0)).item()), + "trainable": int( + torch.as_tensor(self.state.tokens.get("trainable", 0)).item() ), } tokens_state_path = os.path.join(output_dir, TOKENS_STATE_FILE) diff --git a/src/axolotl/utils/callbacks/tokens_per_second.py b/src/axolotl/utils/callbacks/tokens_per_second.py index 3be0deed01..1c02a8b5fd 100644 --- a/src/axolotl/utils/callbacks/tokens_per_second.py +++ b/src/axolotl/utils/callbacks/tokens_per_second.py @@ -52,9 +52,11 @@ def on_train_begin( if os.path.isfile(tokens_state_path): with open(tokens_state_path, "r", encoding="utf-8") as f: tokens_state = json.load(f) - state.total_tokens = torch.tensor(tokens_state.get("total_tokens", 0)) - state.num_tokens = torch.tensor(tokens_state.get("num_tokens", 0)) - LOG.info(f"Restored total_tokens: {state.total_tokens}") + state.tokens = { + "total": torch.tensor(tokens_state.get("total", 0)), + "trainable": torch.tensor(tokens_state.get("trainable", 0)), + } + LOG.info(f"Restored total_tokens: {state.tokens['total']}") def on_step_begin( self, @@ -63,6 +65,10 @@ def on_step_begin( control: TrainerControl, **kwargs, ): # pylint: disable=unused-argument + if not hasattr(state, "tokens"): + state.tokens = {"trainable": torch.zeros(1), "total": torch.zeros(1)} + else: + state.tokens["trainable"] = torch.zeros_like(state.tokens["trainable"]) self.start_time = time.perf_counter() state.last_tokens_per_second = torch.zeros(1) @@ -73,9 +79,10 @@ def on_step_end( control: 
TrainerControl, **kwargs, ): # pylint: disable=unused-argument - if hasattr(state, "num_tokens"): + tokens = getattr(state, "tokens", None) + if tokens and "trainable" in tokens: step_time = time.perf_counter() - self.start_time - num_tokens_per_device = state.num_tokens.clone() + num_tokens_per_device = tokens["trainable"].clone() # non data parallel groups have duplicated tokens, so we avoid double-counting num_tokens_per_device = num_tokens_per_device / self.non_data_parallel_size state.last_tokens_per_second = num_tokens_per_device / step_time @@ -91,4 +98,6 @@ def on_log( # after logging, clear the running metrics if hasattr(state, "last_tokens_per_second"): state.last_tokens_per_second.zero_() - state.num_tokens = torch.zeros(1) + tokens = getattr(state, "tokens", None) + if tokens and "trainable" in tokens: + tokens["trainable"] = torch.zeros_like(tokens["trainable"]) From bd8a9ce233f054e3613531cfc04b5d5e9ccdfd0c Mon Sep 17 00:00:00 2001 From: Ved Date: Tue, 9 Dec 2025 21:27:17 +0530 Subject: [PATCH 05/10] refactor 2 --- src/axolotl/core/trainers/base.py | 6 +++++- .../utils/callbacks/tokens_per_second.py | 17 ++++++++++++----- 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py index 67b642f2b9..8992c74dc9 100644 --- a/src/axolotl/core/trainers/base.py +++ b/src/axolotl/core/trainers/base.py @@ -371,10 +371,14 @@ def compute_loss( } # trainable tokens for throughput and total token slots for summaries - self.state.tokens["trainable"] = trainable_tokens.detach().cpu() + self.state.tokens["trainable"] = ( + self.state.tokens["trainable"] + trainable_tokens.detach().cpu() + ) self.state.tokens["total"] = ( self.state.tokens["total"] + torch.as_tensor(total_tokens).cpu() ) + # Store per-step trainable tokens for throughput calculation + self.state.tokens["trainable_step"] = trainable_tokens.detach().cpu() if self.args.orpo_alpha: return self.orpo_compute_loss( diff --git 
a/src/axolotl/utils/callbacks/tokens_per_second.py b/src/axolotl/utils/callbacks/tokens_per_second.py index 1c02a8b5fd..c21dd9093a 100644 --- a/src/axolotl/utils/callbacks/tokens_per_second.py +++ b/src/axolotl/utils/callbacks/tokens_per_second.py @@ -67,8 +67,6 @@ def on_step_begin( ): # pylint: disable=unused-argument if not hasattr(state, "tokens"): state.tokens = {"trainable": torch.zeros(1), "total": torch.zeros(1)} - else: - state.tokens["trainable"] = torch.zeros_like(state.tokens["trainable"]) self.start_time = time.perf_counter() state.last_tokens_per_second = torch.zeros(1) @@ -80,9 +78,9 @@ def on_step_end( **kwargs, ): # pylint: disable=unused-argument tokens = getattr(state, "tokens", None) - if tokens and "trainable" in tokens: + if tokens and "trainable_step" in tokens: step_time = time.perf_counter() - self.start_time - num_tokens_per_device = tokens["trainable"].clone() + num_tokens_per_device = tokens["trainable_step"].clone() # non data parallel groups have duplicated tokens, so we avoid double-counting num_tokens_per_device = num_tokens_per_device / self.non_data_parallel_size state.last_tokens_per_second = num_tokens_per_device / step_time @@ -97,7 +95,16 @@ def on_log( ): # pylint: disable=unused-argument # after logging, clear the running metrics if hasattr(state, "last_tokens_per_second"): + logs["tokens/trainable_per_second_per_gpu"] = ( + state.last_tokens_per_second.item() + ) state.last_tokens_per_second.zero_() tokens = getattr(state, "tokens", None) + if tokens and "trainable_step" in tokens: + tokens["trainable_step"] = torch.zeros_like(tokens["trainable_step"]) + + if tokens and "total" in tokens: + logs["tokens/total"] = tokens["total"].item() + if tokens and "trainable" in tokens: - tokens["trainable"] = torch.zeros_like(tokens["trainable"]) + logs["tokens/trainable"] = tokens["trainable"].item() From 5ea134456b3659612c0619623308b3d42b11b8f9 Mon Sep 17 00:00:00 2001 From: Ved Date: Tue, 16 Dec 2025 11:44:57 +0530 Subject: [PATCH 
06/10] replace trainable_step/train_per_sec_per_gpu --- src/axolotl/core/trainers/base.py | 5 +++-- src/axolotl/utils/callbacks/tokens_per_second.py | 11 ++++++----- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py index 8992c74dc9..6555cfa542 100644 --- a/src/axolotl/core/trainers/base.py +++ b/src/axolotl/core/trainers/base.py @@ -378,7 +378,8 @@ def compute_loss( self.state.tokens["total"] + torch.as_tensor(total_tokens).cpu() ) # Store per-step trainable tokens for throughput calculation - self.state.tokens["trainable_step"] = trainable_tokens.detach().cpu() + self.state.tokens["trainable_tokens"] = trainable_tokens.detach().cpu() + if self.args.orpo_alpha: return self.orpo_compute_loss( @@ -645,7 +646,7 @@ def log(self, logs: dict[str, float], start_time: float | None = None) -> None: ): # each rank will log its own tokens per second # for logging_steps > 1 we obtain a moving average of this metric - logs["tokens/trainable_per_second_per_gpu"] = round( + logs["tokens/train_per_sec_per_gpu"] = round( self.state.last_tokens_per_second.item() / self.args.logging_steps, 2 ) logs["tokens/total"] = int(self.state.tokens["total"].item()) diff --git a/src/axolotl/utils/callbacks/tokens_per_second.py b/src/axolotl/utils/callbacks/tokens_per_second.py index c21dd9093a..0af305f9f4 100644 --- a/src/axolotl/utils/callbacks/tokens_per_second.py +++ b/src/axolotl/utils/callbacks/tokens_per_second.py @@ -78,9 +78,9 @@ def on_step_end( **kwargs, ): # pylint: disable=unused-argument tokens = getattr(state, "tokens", None) - if tokens and "trainable_step" in tokens: + if tokens and "trainable_tokens" in tokens: step_time = time.perf_counter() - self.start_time - num_tokens_per_device = tokens["trainable_step"].clone() + num_tokens_per_device = tokens["trainable_tokens"].clone() # non data parallel groups have duplicated tokens, so we avoid double-counting num_tokens_per_device = num_tokens_per_device 
/ self.non_data_parallel_size state.last_tokens_per_second = num_tokens_per_device / step_time @@ -95,13 +95,14 @@ def on_log( ): # pylint: disable=unused-argument # after logging, clear the running metrics if hasattr(state, "last_tokens_per_second"): - logs["tokens/trainable_per_second_per_gpu"] = ( + logs["tokens/train_per_sec_per_gpu"] = ( state.last_tokens_per_second.item() ) state.last_tokens_per_second.zero_() tokens = getattr(state, "tokens", None) - if tokens and "trainable_step" in tokens: - tokens["trainable_step"] = torch.zeros_like(tokens["trainable_step"]) + # Clear per-step tokens after logging + if tokens and "trainable_tokens" in tokens: + tokens["trainable_tokens"] = torch.zeros_like(tokens["trainable_tokens"]) if tokens and "total" in tokens: logs["tokens/total"] = tokens["total"].item() From 13761fcb208b9d770dff42c1e0db357f575765e6 Mon Sep 17 00:00:00 2001 From: Ved Date: Tue, 16 Dec 2025 11:52:00 +0530 Subject: [PATCH 07/10] lint + log trainable/tokens --- src/axolotl/core/trainers/base.py | 2 +- src/axolotl/utils/callbacks/tokens_per_second.py | 4 +--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py index 6555cfa542..b534122549 100644 --- a/src/axolotl/core/trainers/base.py +++ b/src/axolotl/core/trainers/base.py @@ -380,7 +380,6 @@ def compute_loss( # Store per-step trainable tokens for throughput calculation self.state.tokens["trainable_tokens"] = trainable_tokens.detach().cpu() - if self.args.orpo_alpha: return self.orpo_compute_loss( model, @@ -650,6 +649,7 @@ def log(self, logs: dict[str, float], start_time: float | None = None) -> None: self.state.last_tokens_per_second.item() / self.args.logging_steps, 2 ) logs["tokens/total"] = int(self.state.tokens["total"].item()) + logs["tokens/trainable"] = int(self.state.tokens["trainable"].item()) del self._stored_metrics[train_eval] diff --git a/src/axolotl/utils/callbacks/tokens_per_second.py 
b/src/axolotl/utils/callbacks/tokens_per_second.py index 0af305f9f4..a1b955a747 100644 --- a/src/axolotl/utils/callbacks/tokens_per_second.py +++ b/src/axolotl/utils/callbacks/tokens_per_second.py @@ -95,9 +95,7 @@ def on_log( ): # pylint: disable=unused-argument # after logging, clear the running metrics if hasattr(state, "last_tokens_per_second"): - logs["tokens/train_per_sec_per_gpu"] = ( - state.last_tokens_per_second.item() - ) + logs["tokens/train_per_sec_per_gpu"] = state.last_tokens_per_second.item() state.last_tokens_per_second.zero_() tokens = getattr(state, "tokens", None) # Clear per-step tokens after logging From 79f247de28fcef5d63e824924ca603d2846cc2c9 Mon Sep 17 00:00:00 2001 From: Ved Date: Fri, 19 Dec 2025 15:03:35 +0530 Subject: [PATCH 08/10] consolidate it in the callback. --- src/axolotl/core/trainers/base.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py index b534122549..b720543bb7 100644 --- a/src/axolotl/core/trainers/base.py +++ b/src/axolotl/core/trainers/base.py @@ -648,9 +648,6 @@ def log(self, logs: dict[str, float], start_time: float | None = None) -> None: logs["tokens/train_per_sec_per_gpu"] = round( self.state.last_tokens_per_second.item() / self.args.logging_steps, 2 ) - logs["tokens/total"] = int(self.state.tokens["total"].item()) - logs["tokens/trainable"] = int(self.state.tokens["trainable"].item()) - del self._stored_metrics[train_eval] return super().log(logs, start_time) From a7d6b7f282e7e7f1573d51878766670eb2ffe8ef Mon Sep 17 00:00:00 2001 From: Ved Date: Wed, 24 Dec 2025 01:08:10 +0530 Subject: [PATCH 09/10] test for total_tokens after resume --- tests/e2e/patched/test_resume.py | 24 +++++++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/tests/e2e/patched/test_resume.py b/tests/e2e/patched/test_resume.py index 747b79dc7c..5d8a7083e1 100644 --- a/tests/e2e/patched/test_resume.py +++ b/tests/e2e/patched/test_resume.py @@ 
-68,6 +68,11 @@ def test_resume_lora_packed(self, temp_dir): normalize_config(cfg) dataset_meta = load_datasets(cfg=cfg) + initial_total_num_tokens = cfg.total_num_tokens + assert initial_total_num_tokens is not None, ( + "total_num_tokens should be calculated during load_datasets" + ) + train(cfg=cfg, dataset_meta=dataset_meta) resume_cfg = cfg | DictDefault( @@ -77,7 +82,24 @@ def test_resume_lora_packed(self, temp_dir): ) normalize_config(resume_cfg) - train(cfg=resume_cfg, dataset_meta=dataset_meta) + assert resume_cfg.total_num_tokens == initial_total_num_tokens, ( + f"total_num_tokens should be preserved on resume. " + f"Expected {initial_total_num_tokens}, got {resume_cfg.total_num_tokens}" + ) + + resume_dataset_meta = load_datasets(cfg=resume_cfg) + + assert resume_cfg.total_num_tokens == initial_total_num_tokens, ( + f"total_num_tokens should not be recalculated when resuming. " + f"Expected {initial_total_num_tokens}, got {resume_cfg.total_num_tokens}" + ) + + train(cfg=resume_cfg, dataset_meta=resume_dataset_meta) + + assert resume_cfg.total_num_tokens == initial_total_num_tokens, ( + f"total_num_tokens should remain unchanged after resume training. 
" + f"Expected {initial_total_num_tokens}, got {resume_cfg.total_num_tokens}" + ) check_model_output_exists(temp_dir, cfg) tb_log_path_1 = most_recent_subdir(temp_dir + "/runs") From 8f3f6c454ec98af65ad1cd7b54577b0822747221 Mon Sep 17 00:00:00 2001 From: Ved Date: Thu, 25 Dec 2025 17:06:07 +0530 Subject: [PATCH 10/10] check if tokenstate exist after ckpt --- tests/e2e/patched/test_resume.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/tests/e2e/patched/test_resume.py b/tests/e2e/patched/test_resume.py index 5d8a7083e1..e6240f2088 100644 --- a/tests/e2e/patched/test_resume.py +++ b/tests/e2e/patched/test_resume.py @@ -2,6 +2,7 @@ E2E tests for resuming training """ +import os import re import subprocess @@ -9,6 +10,7 @@ from axolotl.common.datasets import load_datasets from axolotl.train import train +from axolotl.utils.callbacks.tokens_per_second import TOKENS_STATE_FILE from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.dict import DictDefault @@ -58,6 +60,7 @@ def test_resume_lora_packed(self, temp_dir): "use_tensorboard": True, "save_safetensors": True, "save_first_step": False, + "include_tkps": True, } ) if is_torch_bf16_gpu_available(): @@ -75,6 +78,12 @@ def test_resume_lora_packed(self, temp_dir): train(cfg=cfg, dataset_meta=dataset_meta) + checkpoint_path = f"{temp_dir}/checkpoint-9" + tokens_state_path = os.path.join(checkpoint_path, TOKENS_STATE_FILE) + assert os.path.isfile(tokens_state_path), ( + f"{TOKENS_STATE_FILE} should exist in checkpoint at {tokens_state_path}" + ) + resume_cfg = cfg | DictDefault( { "resume_from_checkpoint": f"{temp_dir}/checkpoint-9/",