From a1522b7c6580efc60844235b085af8034f6ae9b3 Mon Sep 17 00:00:00 2001 From: Your Name Date: Sat, 9 May 2026 08:24:02 +0530 Subject: [PATCH 1/6] activation vram leak --- examples/mistral4/qlora-text.yml | 1 + tests/e2e/test_activation_offloading.py | 54 ++++++++++++++++++++++++- 2 files changed, 54 insertions(+), 1 deletion(-) diff --git a/examples/mistral4/qlora-text.yml b/examples/mistral4/qlora-text.yml index 887ce6da09..4ec2388553 100644 --- a/examples/mistral4/qlora-text.yml +++ b/examples/mistral4/qlora-text.yml @@ -56,3 +56,4 @@ warmup_ratio: 0.1 evals_per_epoch: 1 saves_per_epoch: 1 weight_decay: 0.0 +ddp_find_unused_parameters: true diff --git a/tests/e2e/test_activation_offloading.py b/tests/e2e/test_activation_offloading.py index 5715e68baa..bfff9dc9f3 100644 --- a/tests/e2e/test_activation_offloading.py +++ b/tests/e2e/test_activation_offloading.py @@ -25,6 +25,7 @@ def test_activation_offloading( self, temp_dir, adapter, + monkeypatch, ): cfg = DictDefault( { @@ -47,7 +48,7 @@ def test_activation_offloading( }, ], "num_epochs": 1, - "max_steps": 2, + "max_steps": 10, "micro_batch_size": 1, "gradient_accumulation_steps": 1, "output_dir": temp_dir, @@ -71,9 +72,60 @@ def test_activation_offloading( cfg["adapter"] = "qlora" cfg["load_in_4bit"] = True + # Record OffloadActivations state at the start of each training_step. + # Regression guard for #3638: tracker / dedup map / forward stash must + # be empty at the start of every step. With the leak (pre-fix), these + # grow monotonically and pin GPU memory until OOM. + from axolotl.core.trainers.mixins import activation_checkpointing as ac_mod + + recorded_states: list[dict] = [] + original_training_step = ( + ac_mod.ActivationOffloadingMixin.training_step + ) + + def recording_training_step(self, *args, **kwargs): + ctx = self.activation_offload_context + if isinstance(ctx, ac_mod.OffloadActivations): + recorded_states.append( + { + "step": self._offload_step_counter, + "tracker": len(ctx.tracker), + "storage_dedup": len(ctx.storage_to_tensor_id), + "fwd_stash": len(getattr(ctx, "fwd_stash", {})), + "bwd_tensor_stash": len( + getattr(ctx, "bwd_tensor_stash", {}) + ), + "bwd_ev_stash": len(getattr(ctx, "bwd_ev_stash", {})), + } + ) + return original_training_step(self, *args, **kwargs) + + monkeypatch.setattr( + ac_mod.ActivationOffloadingMixin, + "training_step", + recording_training_step, + ) + cfg = validate_config(cfg) normalize_config(cfg) dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) + + # All recorded pre-step states must be clean: cross-step state never + # carries over. + assert recorded_states, "no training_step recorded — test setup wrong" + for rec in recorded_states: + assert rec["tracker"] == 0, ( + f"OffloadActivations.tracker not empty at start of step " + f"{rec['step']}: {rec} — cross-step leak (#3638) regressed" + ) + assert rec["storage_dedup"] == 0, ( + f"OffloadActivations.storage_to_tensor_id not empty at start " + f"of step {rec['step']}: {rec}" + ) + assert rec["fwd_stash"] == 0, ( + f"OffloadActivations.fwd_stash not empty at start of step " + f"{rec['step']}: {rec}" + ) From 12d85da7ca4dd42cb414972a5176c019fddd8858 Mon Sep 17 00:00:00 2001 From: Your Name Date: Sat, 9 May 2026 09:01:56 +0530 Subject: [PATCH 2/6] lint --- tests/e2e/test_activation_offloading.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/tests/e2e/test_activation_offloading.py b/tests/e2e/test_activation_offloading.py index bfff9dc9f3..b40e8a1ce2 100644 --- a/tests/e2e/test_activation_offloading.py +++ b/tests/e2e/test_activation_offloading.py @@ -79,9 +79,7 @@ def test_activation_offloading( from axolotl.core.trainers.mixins import activation_checkpointing as ac_mod recorded_states: list[dict] = [] - original_training_step = ( - ac_mod.ActivationOffloadingMixin.training_step - ) + original_training_step = ac_mod.ActivationOffloadingMixin.training_step def recording_training_step(self, *args, **kwargs): ctx = self.activation_offload_context @@ -92,9 +90,7 @@ def recording_training_step(self, *args, **kwargs): "tracker": len(ctx.tracker), "storage_dedup": len(ctx.storage_to_tensor_id), "fwd_stash": len(getattr(ctx, "fwd_stash", {})), - "bwd_tensor_stash": len( - getattr(ctx, "bwd_tensor_stash", {}) - ), + "bwd_tensor_stash": len(getattr(ctx, "bwd_tensor_stash", {})), "bwd_ev_stash": len(getattr(ctx, "bwd_ev_stash", {})), } ) From c6b0b604526991d36b512b27756a1065d94a078f Mon Sep 17 00:00:00 2001 From: Your Name Date: Mon, 11 May 2026 19:18:05 +0530 Subject: [PATCH 3/6] enter fix for vram leak --- .../mixins/activation_checkpointing.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/src/axolotl/core/trainers/mixins/activation_checkpointing.py b/src/axolotl/core/trainers/mixins/activation_checkpointing.py index b61c45feed..dd892bebce 100644 --- a/src/axolotl/core/trainers/mixins/activation_checkpointing.py +++ b/src/axolotl/core/trainers/mixins/activation_checkpointing.py @@ -22,6 +22,25 @@ LOG = get_logger(__name__) +# TODO(#3638): drop once TRL pin includes huggingface/trl#5730. Mirrors the +# upstream __enter__ override — clears cross-step state on context re-entry +# so saved tensors that never unpack during backward (MoE / torch.compile) +# don't accumulate as leaked GPU references. +def _axolotl_offload_enter(self): + self.tracker.clear() + self.storage_to_tensor_id.clear() + if self.use_streams: + self.fwd_stash.clear() + self.bwd_tensor_stash.clear() + self.bwd_ev_stash.clear() + self.is_first_forward_call = True + self.is_first_backward_call = True + return super(OffloadActivations, self).__enter__() + + +OffloadActivations.__enter__ = _axolotl_offload_enter + + class ActivationOffloadingMixin(Trainer): """ Trainer mixin class for activation checkpointing w offloading From 7343b96dc2ce5556d060976557def3b8794e1f2a Mon Sep 17 00:00:00 2001 From: VED <146507396+ved1beta@users.noreply.github.com> Date: Sat, 9 May 2026 19:52:54 +0530 Subject: [PATCH 4/6] Update examples/mistral4/qlora-text.yml --- examples/mistral4/qlora-text.yml | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/mistral4/qlora-text.yml b/examples/mistral4/qlora-text.yml index 4ec2388553..887ce6da09 100644 --- a/examples/mistral4/qlora-text.yml +++ b/examples/mistral4/qlora-text.yml @@ -56,4 +56,3 @@ warmup_ratio: 0.1 evals_per_epoch: 1 saves_per_epoch: 1 weight_decay: 0.0 -ddp_find_unused_parameters: true From ca83bff7a137780ade82a0caeffa027f57e26cca Mon Sep 17 00:00:00 2001 From: Your Name Date: Tue, 12 May 2026 18:11:08 +0530 Subject: [PATCH 5/6] undo --- tests/e2e/test_activation_offloading.py | 50 +------------------------ 1 file changed, 1 insertion(+), 49 deletions(-) diff --git a/tests/e2e/test_activation_offloading.py b/tests/e2e/test_activation_offloading.py index b40e8a1ce2..5715e68baa 100644 --- a/tests/e2e/test_activation_offloading.py +++ b/tests/e2e/test_activation_offloading.py @@ -25,7 +25,6 @@ def test_activation_offloading( self, temp_dir, adapter, - monkeypatch, ): cfg = DictDefault( { @@ -48,7 +47,7 @@ def test_activation_offloading( }, ], "num_epochs": 1, - "max_steps": 10, + "max_steps": 2, "micro_batch_size": 1, "gradient_accumulation_steps": 1, "output_dir": temp_dir, @@ -72,56 +71,9 @@ def test_activation_offloading( cfg["adapter"] = "qlora" cfg["load_in_4bit"] = True - # Record OffloadActivations state at the start of each training_step. - # Regression guard for #3638: tracker / dedup map / forward stash must - # be empty at the start of every step. With the leak (pre-fix), these - # grow monotonically and pin GPU memory until OOM. - from axolotl.core.trainers.mixins import activation_checkpointing as ac_mod - - recorded_states: list[dict] = [] - original_training_step = ac_mod.ActivationOffloadingMixin.training_step - - def recording_training_step(self, *args, **kwargs): - ctx = self.activation_offload_context - if isinstance(ctx, ac_mod.OffloadActivations): - recorded_states.append( - { - "step": self._offload_step_counter, - "tracker": len(ctx.tracker), - "storage_dedup": len(ctx.storage_to_tensor_id), - "fwd_stash": len(getattr(ctx, "fwd_stash", {})), - "bwd_tensor_stash": len(getattr(ctx, "bwd_tensor_stash", {})), - "bwd_ev_stash": len(getattr(ctx, "bwd_ev_stash", {})), - } - ) - return original_training_step(self, *args, **kwargs) - - monkeypatch.setattr( - ac_mod.ActivationOffloadingMixin, - "training_step", - recording_training_step, - ) - cfg = validate_config(cfg) normalize_config(cfg) dataset_meta = load_datasets(cfg=cfg) train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) - - # All recorded pre-step states must be clean: cross-step state never - # carries over. - assert recorded_states, "no training_step recorded — test setup wrong" - for rec in recorded_states: - assert rec["tracker"] == 0, ( - f"OffloadActivations.tracker not empty at start of step " - f"{rec['step']}: {rec} — cross-step leak (#3638) regressed" - ) - assert rec["storage_dedup"] == 0, ( - f"OffloadActivations.storage_to_tensor_id not empty at start " - f"of step {rec['step']}: {rec}" - ) - assert rec["fwd_stash"] == 0, ( - f"OffloadActivations.fwd_stash not empty at start of step " - f"{rec['step']}: {rec}" - ) From 3fc7020ddd2e19269d8c822bc0d4ae9ec7a9cbc1 Mon Sep 17 00:00:00 2001 From: Your Name Date: Tue, 12 May 2026 21:06:07 +0530 Subject: [PATCH 6/6] leak test --- tests/e2e/test_activation_offloading.py | 104 ++++++++++++++++++++++++ 1 file changed, 104 insertions(+) diff --git a/tests/e2e/test_activation_offloading.py b/tests/e2e/test_activation_offloading.py index 5715e68baa..ee3fdbb6cf 100644 --- a/tests/e2e/test_activation_offloading.py +++ b/tests/e2e/test_activation_offloading.py @@ -5,6 +5,9 @@ import pytest from axolotl.common.datasets import load_datasets +from axolotl.core.trainers.mixins.activation_checkpointing import ( + ActivationOffloadingMixin, +) from axolotl.train import train from axolotl.utils.config import normalize_config, validate_config from axolotl.utils.dict import DictDefault @@ -77,3 +80,104 @@ def test_activation_offloading( train(cfg=cfg, dataset_meta=dataset_meta) check_model_output_exists(temp_dir, cfg) + + def test_no_vram_leak_regression(self, temp_dir, monkeypatch): + """#3638 regression — fail on linear VRAM growth across training steps. + + The bug: ``OffloadActivations.__enter__`` doesn't clear cross-step + state, so a saved tensor that never unpacks during backward + (MoE / ``torch.compile``) sits in ``ctx.tracker`` forever — and its + GPU storage stays alive. Across many steps memory grows linearly. + + Tiny CI models won't exhibit the upstream MoE/compile unpack failure + on their own, so we *inject* the same leftover: after every step we + stash a small CUDA tensor into ``ctx.tracker``. The fix clears it on + the next ``__enter__`` (memory flat); without the fix it accumulates + (memory grows ~constant bytes/step). The fail mode is the bug's own + symptom — ``torch.cuda.memory_allocated`` increasing across steps. + """ + import torch + + if not torch.cuda.is_available(): + pytest.skip("VRAM-leak test requires CUDA") + + mem_per_step: list[int] = [] + seed_id = [10**9] + seed_bytes = 4 * 1024 * 1024 # 4 MB / step + + original_step = ActivationOffloadingMixin.training_step + + def wrapped_step(self, *args, **kwargs): + torch.cuda.synchronize() + mem_per_step.append(torch.cuda.memory_allocated()) + out = original_step(self, *args, **kwargs) + + # Inject the MoE-style leftover: a CUDA tensor stuck in + # OffloadActivations.tracker. The local `seed` ref dies on + # return — only ctx.tracker keeps it alive, so the next + # __enter__'s clear (with the fix) actually releases the GPU + # memory. Without the fix these accumulate step-over-step. + ctx = self.activation_offload_context + seed_id[0] += 1 + seed = torch.empty(seed_bytes // 2, dtype=torch.float16, device="cuda") + ctx.tracker[seed_id[0]] = (seed, False, None, None, None) + # Stop the next forward's pack_tensor from raising on its + # "tracker should have been cleared" guard. With the fix this + # flag gets reset by __enter__ anyway; on main it would + # otherwise crash before our VRAM measurement on step 2. + ctx.is_first_forward_call = False + return out + + monkeypatch.setattr(ActivationOffloadingMixin, "training_step", wrapped_step) + + cfg = DictDefault( + { + "base_model": "HuggingFaceTB/SmolLM2-135M", + "sequence_len": 1024, + "val_set_size": 0.0, + "special_tokens": {"pad_token": "<|endoftext|>"}, + "datasets": [ + {"path": "mhenrichsen/alpaca_2k_test", "type": "alpaca"}, + ], + "max_steps": 10, + "micro_batch_size": 1, + "gradient_accumulation_steps": 1, + "output_dir": temp_dir, + "learning_rate": 1e-5, + "optimizer": "adamw_torch", + "lr_scheduler": "cosine", + "flash_attention": True, + "bf16": "auto", + "gradient_checkpointing": True, + "activation_offloading": True, + "save_first_step": False, + } + ) + cfg = validate_config(cfg) + normalize_config(cfg) + dataset_meta = load_datasets(cfg=cfg) + train(cfg=cfg, dataset_meta=dataset_meta) + + # Drop warm-up steps; allocator settling distorts early samples. + warmup = 3 + samples = mem_per_step[warmup:] + assert len(samples) >= 5, ( + f"need >= 5 post-warmup samples, got {len(samples)} " + f"(total {len(mem_per_step)})" + ) + + # Injection is 4 MB/step. With the fix __enter__ clears each seed + # before the next step → growth ≈ 0. Without the fix seeds pile up + # → growth ≈ 4 MB × (steps-1). 10 MB is well above allocator jitter + # and well below the leaky-build floor. + growth_mb = (samples[-1] - samples[0]) / (1024**2) + tolerance_mb = 10 + + per_step_mb = [round(m / 1024**2, 1) for m in mem_per_step] + assert growth_mb < tolerance_mb, ( + f"VRAM grew {growth_mb:.1f} MB across {len(samples)} post-warmup " + f"steps — linear-increase signature of the #3638 VRAM leak. " + f"Per-step memory_allocated (MB): {per_step_mb}" + ) + + check_model_output_exists(temp_dir, cfg)