Skip to content

[Bug] GRPO OOM When Resuming From Checkpoint With Unsloth Standby #3302

@BenjaminBruenau

Description

@BenjaminBruenau
  1. Did you update? yes
  2. Colab or Kaggle or local / cloud: local
  3. Number GPUs used, use nvidia-smi: 1 (RTX 3090 24GB)
  4. Which notebook? Please link!: custom
  5. Which Unsloth version, TRL version, transformers version, PyTorch version?
unsloth                                  2025.9.2
unsloth-zoo                              2025.9.3
trl                                      0.22.2
transformers                             4.56.1
torch                                    2.7.1
triton                                   3.3.1
vllm                                     0.10.1.1
  6. Which trainer? GRPOTrainer
Code
import os
import sys
from typing import Optional

# Unsloth switches are set through the environment *before* `unsloth` is
# imported below. Assumption: Unsloth reads them at import/initialisation
# time, so keep this ordering.
# UNSLOTH_VLLM_STANDBY=1 opts into Unsloth's vLLM standby mode (per Unsloth
# docs, vLLM memory is put to sleep during training and woken for generation).
# UNSLOTH_ENABLE_LOGGING=1 turns on Unsloth's extra logging.
os.environ["UNSLOTH_VLLM_STANDBY"] = "1"
os.environ["UNSLOTH_ENABLE_LOGGING"] = "1"

# Expose only physical GPU 1 to this process (it then appears as cuda:0).
# This must be set before CUDA is initialised by any of the imports below.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

# Unsloth recommends importing it before trl/transformers so its patches apply.
from unsloth import FastLanguageModel, is_bfloat16_supported


from trl import GRPOConfig, GRPOTrainer

# Startup diagnostics: process id, interpreter version and working directory.
print(f"[PID {os.getpid()}] Script start. Python version: {sys.version}", flush=True)
print(f"[PID {os.getpid()}] Current PWD: {os.getcwd()}", flush=True)

# --- Model Configuration ---
# Maximum sequence length passed to FastLanguageModel.from_pretrained. The
# GRPOConfig below splits it as max_prompt_length (2500) + max_completion_length (1500).
max_seq_length: int = 4000
# LoRA rank r. It is also passed as max_lora_rank when loading, and lora_alpha = 2 * r.
lora_rank: int = 128
# When True: reattach to wandb run `wandb_run_id` and call
# trainer.train(resume_from_checkpoint=True). Per the HF Trainer docs, that
# restores the latest checkpoint found in `output_directory`.
resume_from_checkpoint: bool = True
# GRPOConfig.output_dir: checkpoints are written here (every save_steps), and
# a resume looks for them here.
output_directory: str = "outputs"
# NOTE(review): empty placeholder. With resume_from_checkpoint=True this value
# is passed to wandb.init(resume='must', id=...), so it must be set to the id
# of the run being resumed.
wandb_run_id: str = ""

# --- Model Loading ---
print(f"[PID {os.getpid()}, ] Loading model...", flush=True)
# Start from a checkpoint of an earlier run instead of the base model.
# "unsloth/Qwen3-4B-Instruct-2507" is presumably the original base; swap it
# back in to train from scratch.
model_name = "outputs_reasoning/checkpoint-40"

# Loader options for FastLanguageModel.from_pretrained:
#   load_in_4bit           - load the base weights 4-bit quantised
#   fast_inference         - enable Unsloth's vLLM-backed generation (per Unsloth docs)
#   max_lora_rank          - the same lora_rank that get_peft_model applies below
#   gpu_memory_utilization - share of GPU memory presumably forwarded to vLLM
# attn_implementation="eager" can be added here if needed.
# NOTE(review): in standby mode vLLM frees this memory when it sleeps and maps
# it again on wake_up. If a resumed run hits CUDA OOM on wake-up, a lower
# value may help; this is not verified.
model_kwargs = dict(
    model_name=model_name,
    max_seq_length=max_seq_length,
    load_in_4bit=True,
    fast_inference=True,
    max_lora_rank=lora_rank,
    gpu_memory_utilization=0.6,
)

model, tokenizer = FastLanguageModel.from_pretrained(**model_kwargs)
print(f"[PID {os.getpid()}, ] Model loaded.", flush=True)

# --- LoRA Application ---
# Attach LoRA adapters to the attention (q/k/v/o_proj) and MLP
# (gate/up/down_proj) projections, going by the usual Qwen/Llama module names.
# NOTE(review): model_name points at a training checkpoint. If it already
# carries LoRA adapters, confirm get_peft_model does not stack a second one.
print(f"[PID {os.getpid()}, ] Applying LoRA...", flush=True)
model = FastLanguageModel.get_peft_model(
    model,
    r=lora_rank,
    lora_alpha=2 * lora_rank,  # alpha = 2 * r
    target_modules="q_proj k_proj v_proj o_proj gate_proj up_proj down_proj".split(),
    use_gradient_checkpointing="unsloth",  # Unsloth's own checkpointing mode
    random_state=3407,
)
print(f"[PID {os.getpid()}, ] LoRA applied.", flush=True)


# --- Dataset Loading ---
# `...` (Ellipsis) stands in for the real dataset construction, which is
# elided in this snippet. Replace it before running. The two objects are
# passed to GRPOTrainer as train_dataset / eval_dataset.
# Presumably each row provides the field the reward function receives as
# `answer`; confirm against the dataset schema.
print(f"[PID {os.getpid()}, ] Loading dataset...", flush=True)

dataset = ...
print(f"[PID {os.getpid()}, ] Dataset loaded.", flush=True)

print(f"[PID {os.getpid()}, ] Loading evaluation dataset...", flush=True)
evaluation_dataset = ...
print(f"[PID {os.getpid()}, ] Evaluation Dataset loaded.", flush=True)

# --- Reward Function Definitions ---
def correctness_reward_func_slow(prompts, completions, answer, **kwargs) -> list[float]:
    """Placeholder reward: give every completion the same score of 0.1.

    ``prompts``, ``answer`` and any extra keyword arguments are accepted
    (presumably to match how GRPOTrainer invokes reward functions) but are
    not used.

    Returns:
        A list holding one reward of 0.1 per element of ``completions``.
    """
    constant_reward = 0.1
    return [constant_reward for _ in completions]

# --- Training Configuration ---
print(f"[PID {os.getpid()}, ] Setting up training...", flush=True)

# vLLM is imported here, after unsloth has already been imported above.
from vllm import SamplingParams

# Sampling settings for GRPO rollouts. Temperature is not set here; it is
# passed to GRPOConfig below. top_k=-1 and top_p=1.0 leave top-k and nucleus
# filtering off, so min_p=0.1 is the only probability cut-off configured here.
vllm_sampling_params = SamplingParams(
    # probability truncation
    top_k=-1,
    top_p=1.0,
    min_p=0.1,
    # termination: stop on the tokenizer's EOS string and keep it in the output
    stop=[tokenizer.eos_token],
    include_stop_str_in_output=True,
    # reproducibility
    seed=3407,
)


# GRPO training hyperparameters. Token budget per sample:
# max_prompt_length (2500) + max_completion_length (1500) == max_seq_length (4000).
training_args = GRPOConfig(
    # --- rollout generation ---
    # use_vllm is not passed. Presumably generation goes through the vLLM
    # engine Unsloth builds for fast_inference=True; confirm.
    vllm_sampling_params=vllm_sampling_params,
    temperature=1.0,
    num_generations=8,
    max_prompt_length=2500,
    max_completion_length=1500,
    # --- optimiser ---
    optim="adamw_8bit",
    learning_rate=5e-6,
    adam_beta1=0.9,
    adam_beta2=0.99,
    weight_decay=0.1,
    max_grad_norm=0.1,
    # --- schedule: linear decay ("cosine" is the alternative) ---
    lr_scheduler_type="linear",
    warmup_ratio=0.1,
    max_steps=2500,
    # --- batching ---
    per_device_train_batch_size=8,
    gradient_accumulation_steps=4,
    # --- precision: bf16 when the GPU supports it, otherwise fp16 ---
    bf16=is_bfloat16_supported(),
    fp16=not is_bfloat16_supported(),
    # --- logging / checkpointing: checkpoints land in output_dir every 20 steps ---
    logging_steps=1,
    report_to="wandb",
    output_dir=output_directory,
    save_steps=20,
    # --- evaluation every 100 steps ---
    eval_strategy="steps",
    eval_steps=100,
    per_device_eval_batch_size=16,
    eval_accumulation_steps=1,
    fp16_full_eval=False,
)

print(f"[PID {os.getpid()}, ] Training arguments set.", flush=True)


print(f"[PID {os.getpid()}, ] Preparing trainer...", flush=True)

# On resume, reattach to the existing wandb run before the trainer is built
# (report_to="wandb"), presumably so logging continues in that run.
# resume='must' makes wandb raise an error instead of starting a fresh run
# when the id is unknown. project='huggingface' presumably matches the
# default project the HF wandb integration logged the original run to; confirm.
if resume_from_checkpoint:
    print(f"[PID {os.getpid()}, ] Resuming from checkpoint...", flush=True)
    import wandb
    wandb.init(id=wandb_run_id, project='huggingface', resume='must')

trainer = GRPOTrainer(
    model=model,
    args=training_args,
    processing_class=tokenizer,
    reward_funcs=[correctness_reward_func_slow],  # constant placeholder reward
    train_dataset=dataset,
    eval_dataset=evaluation_dataset,
)

print(f"[PID {os.getpid()}, ] Trainer prepared.", flush=True)



print(f"[PID {os.getpid()}, ] Starting training...", flush=True)
# True restores the latest checkpoint in output_dir. Otherwise None is passed,
# which is Trainer.train's default and so the same as calling train() bare.
trainer.train(resume_from_checkpoint=True if resume_from_checkpoint else None)
Full Error Message and Traceback
INFO 09-10 12:12:34 [block_pool.py:280] Successfully reset prefix cache
INFO 09-10 12:12:34 [gpu_worker.py:115] Sleep mode freed 9.19 GiB memory, 7.66 GiB memory is still in use.
INFO 09-10 12:12:34 [executor_base.py:188] It took 0.498879 seconds to fall asleep.
CUDA Error: out of memory at /workspace/csrc/cumem_allocator.cpp:62
Traceback (most recent call last):
  File "/home/brr/dev/RL/packages/sql/src/sql/train_reasoning.py", line 303, in <module>
    trainer.train(resume_from_checkpoint=True)
  File "/home/brr/dev/RL/.venv/lib/python3.12/site-packages/transformers/trainer.py", line 2328, in train
    return inner_training_loop(
           ^^^^^^^^^^^^^^^^^^^^
  File "<string>", line 325, in _fast_inner_training_loop
  File "<string>", line 34, in _unsloth_training_step
  File "/home/brr/dev/RL/.venv/lib/python3.12/site-packages/trl/extras/profiling.py", line 98, in wrapper
    return func(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/brr/dev/RL/packages/sql/src/sql/unsloth_compiled_cache/UnslothGRPOTrainer.py", line 1795, in _prepare_inputs
    self.llm.wake_up()
  File "/home/brr/dev/RL/.venv/lib/python3.12/site-packages/vllm/entrypoints/llm.py", line 1572, in wake_up
    self.llm_engine.wake_up(tags)
  File "/home/brr/dev/RL/.venv/lib/python3.12/site-packages/vllm/v1/engine/llm_engine.py", line 285, in wake_up
    self.engine_core.wake_up(tags)
  File "/home/brr/dev/RL/.venv/lib/python3.12/site-packages/vllm/v1/engine/core_client.py", line 278, in wake_up
    self.engine_core.wake_up(tags)
  File "/home/brr/dev/RL/.venv/lib/python3.12/site-packages/vllm/v1/engine/core.py", line 374, in wake_up
    self.model_executor.wake_up(tags)
  File "/home/brr/dev/RL/.venv/lib/python3.12/site-packages/vllm/executor/executor_base.py", line 202, in wake_up
    self.collective_rpc("wake_up", kwargs=dict(tags=tags))
  File "/home/brr/dev/RL/.venv/lib/python3.12/site-packages/vllm/executor/uniproc_executor.py", line 58, in collective_rpc
    answer = run_method(self.driver_worker, method, args, kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/brr/dev/RL/.venv/lib/python3.12/site-packages/vllm/utils/__init__.py", line 3007, in run_method
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/brr/dev/RL/.venv/lib/python3.12/site-packages/vllm/v1/worker/gpu_worker.py", line 124, in wake_up
    allocator.wake_up(tags)
  File "/home/brr/dev/RL/.venv/lib/python3.12/site-packages/unsloth_zoo/vllm_utils.py", line 573, in wake_up
    create_and_map(handle)
  File "/home/brr/dev/RL/.venv/lib/python3.12/site-packages/vllm/device_allocator/cumem.py", line 78, in create_and_map
    python_create_and_map(*allocation_handle)
RuntimeError: CUDA Error: out of memory at /workspace/csrc/cumem_allocator.cpp:62

Hello there!

First of all, I really enjoy the latest GRPO update and the memory savings that come with it.
I was able to double my LoRA rank during training, and I still see some headroom for a further increase (around 70% VRAM usage with r=128 during training, compared to ~90% with r=64 before).

However, when I try to resume training from a checkpoint, I quickly run into CUDA out-of-memory errors (I believe shortly before the first step completes).
Based on the traceback, this might be related to the new Standby feature (UNSLOTH_VLLM_STANDBY) used together with vLLM, although VRAM usage also seems to be higher from the start than when not resuming from a checkpoint.

I also tried training from a checkpoint as a fresh run (loading the checkpoint as the model but starting a new training run instead of resuming), and that works without any issues.

🦥

Metadata

Metadata

Assignees

No one assigned

    Labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions