Description
- Did you update? yes
- Colab or Kaggle or local / cloud: local
- Number GPUs used, use nvidia-smi: 1 (RTX 3090 24GB)
- Which notebook? Please link!: custom
- Which Unsloth version, TRL version, transformers version, PyTorch version?
unsloth 2025.9.2
unsloth-zoo 2025.9.3
trl 0.22.2
transformers 4.56.1
torch 2.7.1
triton 3.3.1
vllm 0.10.1.1
- Which trainer? GRPOTrainer
Code
import os
import sys
from typing import Optional
os.environ["UNSLOTH_VLLM_STANDBY"] = "1"
os.environ["UNSLOTH_ENABLE_LOGGING"] = "1"
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
from unsloth import FastLanguageModel, is_bfloat16_supported
from trl import GRPOConfig, GRPOTrainer
print(f"[PID {os.getpid()}] Script start. Python version: {sys.version}", flush=True)
print(f"[PID {os.getpid()}] Current PWD: {os.getcwd()}", flush=True)
# --- Model Configuration ---
max_seq_length = 4000
lora_rank = 128
resume_from_checkpoint = True
output_directory = "outputs"
wandb_run_id = ""
# --- Model Loading ---
print(f"[PID {os.getpid()}, ] Loading model...", flush=True)
model_name = "outputs_reasoning/checkpoint-40"  # "unsloth/Qwen3-4B-Instruct-2507"
model_kwargs = {
    "model_name": model_name,
    "max_seq_length": max_seq_length,
    "load_in_4bit": True,
    "fast_inference": True,
    "max_lora_rank": lora_rank,
    "gpu_memory_utilization": 0.6,
    # "attn_implementation": 'eager',
}
model, tokenizer = FastLanguageModel.from_pretrained(**model_kwargs)
print(f"[PID {os.getpid()}, ] Model loaded.", flush=True)
# --- LoRA Application ---
print(f"[PID {os.getpid()}, ] Applying LoRA...", flush=True)
model = FastLanguageModel.get_peft_model(
    model,
    r=lora_rank,
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
    lora_alpha=lora_rank*2,
    use_gradient_checkpointing="unsloth",
    random_state=3407,
)
print(f"[PID {os.getpid()}, ] LoRA applied.", flush=True)
# --- Dataset Loading ---
print(f"[PID {os.getpid()}, ] Loading dataset...", flush=True)
dataset = ...
print(f"[PID {os.getpid()}, ] Dataset loaded.", flush=True)
print(f"[PID {os.getpid()}, ] Loading evaluation dataset...", flush=True)
evaluation_dataset = ...
print(f"[PID {os.getpid()}, ] Evaluation Dataset loaded.", flush=True)
# --- Reward Function Definitions ---
def correctness_reward_func_slow(prompts, completions, answer, **kwargs) -> list[float]:
    return [0.1 for completion in completions]
# --- Training Configuration ---
print(f"[PID {os.getpid()}, ] Setting up training...", flush=True)
from vllm import SamplingParams
vllm_sampling_params = SamplingParams(
    min_p = 0.1,
    top_p = 1.0,
    top_k = -1,
    seed = 3407,
    stop = [tokenizer.eos_token],
    include_stop_str_in_output = True,
)
training_args = GRPOConfig(
    # use_vllm=True,
    vllm_sampling_params = vllm_sampling_params,
    temperature = 1.0,
    learning_rate=5e-6,
    adam_beta1=0.9,
    adam_beta2=0.99,
    weight_decay=0.1,
    warmup_ratio=0.1,
    lr_scheduler_type="linear",  # cosine
    optim="adamw_8bit",
    logging_steps=1,
    bf16=is_bfloat16_supported(),
    fp16=not is_bfloat16_supported(),
    per_device_train_batch_size=8,
    gradient_accumulation_steps=4,
    num_generations=8,
    max_prompt_length=2500,
    max_completion_length=1500,
    max_steps=2500,
    save_steps=20,
    max_grad_norm=0.1,
    report_to="wandb",
    output_dir=output_directory,
    fp16_full_eval = False,
    per_device_eval_batch_size = 16,
    eval_accumulation_steps = 1,
    eval_strategy = "steps",
    eval_steps = 100,
)
print(f"[PID {os.getpid()}, ] Training arguments set.", flush=True)
print(f"[PID {os.getpid()}, ] Preparing trainer...", flush=True)
if resume_from_checkpoint:
    print(f"[PID {os.getpid()}, ] Resuming from checkpoint...", flush=True)
    import wandb
    wandb.init(project='huggingface', resume='must', id=wandb_run_id)
trainer = GRPOTrainer(
    model=model,
    processing_class=tokenizer,
    reward_funcs=[correctness_reward_func_slow],
    args=training_args,
    train_dataset=dataset,
    eval_dataset=evaluation_dataset,
)
print(f"[PID {os.getpid()}, ] Trainer prepared.", flush=True)
print(f"[PID {os.getpid()}, ] Starting training...", flush=True)
if resume_from_checkpoint:
    trainer.train(resume_from_checkpoint=True)
else:
    trainer.train()
Full Error Message and Traceback
INFO 09-10 12:12:34 [block_pool.py:280] Successfully reset prefix cache
INFO 09-10 12:12:34 [gpu_worker.py:115] Sleep mode freed 9.19 GiB memory, 7.66 GiB memory is still in use.
INFO 09-10 12:12:34 [executor_base.py:188] It took 0.498879 seconds to fall asleep.
CUDA Error: out of memory at /workspace/csrc/cumem_allocator.cpp:62
Traceback (most recent call last):
File "/home/brr/dev/RL/packages/sql/src/sql/train_reasoning.py", line 303, in <module>
trainer.train(resume_from_checkpoint=True)
File "/home/brr/dev/RL/.venv/lib/python3.12/site-packages/transformers/trainer.py", line 2328, in train
return inner_training_loop(
^^^^^^^^^^^^^^^^^^^^
File "<string>", line 325, in _fast_inner_training_loop
File "<string>", line 34, in _unsloth_training_step
File "/home/brr/dev/RL/.venv/lib/python3.12/site-packages/trl/extras/profiling.py", line 98, in wrapper
return func(self, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/brr/dev/RL/packages/sql/src/sql/unsloth_compiled_cache/UnslothGRPOTrainer.py", line 1795, in _prepare_inputs
self.llm.wake_up()
File "/home/brr/dev/RL/.venv/lib/python3.12/site-packages/vllm/entrypoints/llm.py", line 1572, in wake_up
self.llm_engine.wake_up(tags)
File "/home/brr/dev/RL/.venv/lib/python3.12/site-packages/vllm/v1/engine/llm_engine.py", line 285, in wake_up
self.engine_core.wake_up(tags)
File "/home/brr/dev/RL/.venv/lib/python3.12/site-packages/vllm/v1/engine/core_client.py", line 278, in wake_up
self.engine_core.wake_up(tags)
File "/home/brr/dev/RL/.venv/lib/python3.12/site-packages/vllm/v1/engine/core.py", line 374, in wake_up
self.model_executor.wake_up(tags)
File "/home/brr/dev/RL/.venv/lib/python3.12/site-packages/vllm/executor/executor_base.py", line 202, in wake_up
self.collective_rpc("wake_up", kwargs=dict(tags=tags))
File "/home/brr/dev/RL/.venv/lib/python3.12/site-packages/vllm/executor/uniproc_executor.py", line 58, in collective_rpc
answer = run_method(self.driver_worker, method, args, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/brr/dev/RL/.venv/lib/python3.12/site-packages/vllm/utils/__init__.py", line 3007, in run_method
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/home/brr/dev/RL/.venv/lib/python3.12/site-packages/vllm/v1/worker/gpu_worker.py", line 124, in wake_up
allocator.wake_up(tags)
File "/home/brr/dev/RL/.venv/lib/python3.12/site-packages/unsloth_zoo/vllm_utils.py", line 573, in wake_up
create_and_map(handle)
File "/home/brr/dev/RL/.venv/lib/python3.12/site-packages/vllm/device_allocator/cumem.py", line 78, in create_and_map
python_create_and_map(*allocation_handle)
RuntimeError: CUDA Error: out of memory at /workspace/csrc/cumem_allocator.cpp:62
Hello there!
First of all, I really enjoy the latest GRPO update and the memory savings that come with it.
I was able to double my LoRA rank during training and I still see some leeway for another increase (around 70% VRAM usage with r=128 while training, compared to ~90% with r=64 before).
However, when I try to resume training from a checkpoint, I quickly run into CUDA out-of-memory errors (I believe shortly before the first step completes).
Based on the traceback, this might be related to the new standby feature used in conjunction with vLLM, since the error is raised inside allocator.wake_up() (although the initial VRAM usage also seems to be higher than when not resuming from a checkpoint).
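To make that comparison concrete, I'm planning to drop a small memory probe right before trainer.train() in both the fresh and the resumed run. This is only a diagnostic sketch using plain PyTorch calls; log_gpu_memory is my own helper and not part of Unsloth or vLLM:

import torch

def log_gpu_memory(tag: str) -> None:
    # Diagnostic helper: prints what PyTorch has allocated/reserved plus the
    # device-wide free/total memory reported by the CUDA driver.
    free, total = torch.cuda.mem_get_info()
    print(
        f"[{tag}] allocated={torch.cuda.memory_allocated() / 2**30:.2f} GiB | "
        f"reserved={torch.cuda.memory_reserved() / 2**30:.2f} GiB | "
        f"free={free / 2**30:.2f} / {total / 2**30:.2f} GiB",
        flush=True,
    )

log_gpu_memory("before trainer.train()")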
I also tried retraining from the checkpoint (i.e. loading the checkpoint weights but starting a fresh training run), and that works without any issues.
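As a possible workaround (purely an assumption on my part that the standby wake-up is what breaks resuming, based on the traceback above), I'm considering enabling standby only for fresh runs, something like:

# Hypothetical workaround: keep vLLM standby off when resuming, since the OOM
# happens inside allocator.wake_up() during the first resumed step.
if not resume_from_checkpoint:
    os.environ["UNSLOTH_VLLM_STANDBY"] = "1"

This would mean moving the resume_from_checkpoint flag above the os.environ block at the top of the script, and of course it gives up the standby memory savings for resumed runs.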
🦥