nemo_rl/algorithms/grpo.py (86 changes: 52 additions & 34 deletions)
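Every hunk in this diff applies the same mechanical change: progress prints gain flush=True so output is written immediately instead of sitting in Python's stdout buffer. This matters because stdout is only line-buffered when attached to a terminal; when it is a pipe (as under Ray workers, Slurm, or tee) it is block-buffered, and status lines can appear late or interleave badly across ranks. A minimal standalone sketch of the before/after pattern (illustration only, not code from this PR):

import sys

# When stdout is a pipe, this line may sit in the buffer for a long time:
print("step started")

# flush=True drains the buffer immediately, so the line is visible at once:
print("step started", flush=True)

# Equivalent explicit form:
sys.stdout.write("step started\n")
sys.stdout.flush()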
@@ -196,7 +196,7 @@ def setup(
)
dataloader.load_state_dict(dataloader_state_dict)

-print(f" ✓ Training dataloader loaded with {len(dataset)} samples")
+print(f" ✓ Training dataloader loaded with {len(dataset)} samples", flush=True)

# Load validation dataset if provided
val_dataloader: Optional[StatefulDataLoader] = None
@@ -211,12 +211,15 @@ def setup(
shuffle=False,
collate_fn=rl_collate_fn,
)
-print(f" ✓ Validation dataloader loaded with {len(val_dataset)} samples")
+print(
+    f" ✓ Validation dataloader loaded with {len(val_dataset)} samples",
+    flush=True,
+)

# ==========================
# Cluster
# ==========================
-print("\n▶ Setting up compute cluster...")
+print("\n▶ Setting up compute cluster...", flush=True)
colocated_inference = generation_config["colocated"]["enabled"]

if colocated_inference:
@@ -232,7 +235,10 @@ def setup(
)
train_cluster = cluster
inference_cluster = cluster
-print(f" ✓ Ray cluster initialized with {cluster_config['num_nodes']} nodes")
+print(
+    f" ✓ Ray cluster initialized with {cluster_config['num_nodes']} nodes",
+    flush=True,
+)

else:
assert generation_config["backend"] != "megatron", (
@@ -288,7 +294,8 @@ def setup(
max_colocated_worker_groups=1,
)
print(
-    f" ✓ Ray train cluster initialized with {train_nodes} nodes with {train_gpus_per_node} GPUs per node"
+    f" ✓ Ray train cluster initialized with {train_nodes} nodes with {train_gpus_per_node} GPUs per node",
+    flush=True,
)

# initialize inference cluster
@@ -300,13 +307,14 @@ def setup(
max_colocated_worker_groups=1,
)
print(
-    f" ✓ Ray inference cluster initialized with {inference_nodes} nodes with {inference_gpus_per_node} GPUs per node"
+    f" ✓ Ray inference cluster initialized with {inference_nodes} nodes with {inference_gpus_per_node} GPUs per node",
+    flush=True,
)

# ==========================
# Training and Inference
# ==========================
-print("\n▶ Setting up model and training...")
+print("\n▶ Setting up model and training...", flush=True)

# vLLM model loading prefers a clean environment; initialize policy_generation before policy (#52 will fix this)
backend = generation_config["backend"]
@@ -315,7 +323,8 @@ def setup(
if backend == "megatron":
policy_generation = None
print(
-    f" ✓ Using {backend} backend for generation with {policy_config['model_name']}"
+    f" ✓ Using {backend} backend for generation with {policy_config['model_name']}",
+    flush=True,
)
elif backend == "vllm":
generation_config = cast(VllmConfig, generation_config)
@@ -331,7 +340,8 @@ def setup(
# vLLM 0.8 fails in initialization if it's called in the first training step, since it has no clean view of the GPU memory (HF is sharing the same memory).
policy_generation.finish_generation()
print(
-    f" ✓ Using vLLM backend for generation with {policy_config['model_name']}"
+    f" ✓ Using vLLM backend for generation with {policy_config['model_name']}",
+    flush=True,
)

if last_checkpoint_path:
@@ -354,7 +364,7 @@ def setup(
# if inference is not colocated, initialize collective communication for weight updates
if not colocated_inference:
ip, port = train_cluster.get_master_address_and_port()
-print(f"Using ip: {ip}, port: {port} for collective communication")
+print(f"Using ip: {ip}, port: {port} for collective communication", flush=True)
# inference cluster + head node of the train cluster
world_size = inference_nodes * inference_gpus_per_node + 1
# init collective
@@ -371,7 +381,7 @@ def setup(

print("\n" + "=" * 60)
print(" " * 18 + "SETUP COMPLETE")
-print("=" * 60 + "\n")
+print("=" * 60 + "\n", flush=True)

return (
policy,
@@ -445,7 +455,8 @@ def refit_policy_generation(
)
total_num_keys = sum(len(k) for k in grouped_param_keys)
print(
-    f"[Refit] Split {total_num_keys} keys into {len(grouped_param_keys)} groups"
+    f"[Refit] Split {total_num_keys} keys into {len(grouped_param_keys)} groups",
+    flush=True,
)
# do update
for keys in grouped_param_keys:
@@ -524,7 +535,7 @@ def grpo_train(

# Run validation at the start if configured
if val_at_start and step == 0:
-print("\n🔍 Running initial validation...")
+print("\n🔍 Running initial validation...", flush=True)
if NEED_REFIT and POLICY_GENERATION_STALE:
refit_policy_generation(policy, policy_generation, colocated_inference)
POLICY_GENERATION_STALE = False
@@ -546,7 +557,8 @@ def grpo_train(
batch: BatchedDataDict[DatumSpec]
for batch in dataloader:
print(
-    f"\n{'=' * 25} Step {step + 1}/{min(len(dataloader), master_config['grpo']['max_num_steps'])} {'=' * 25}"
+    f"\n{'=' * 25} Step {step + 1}/{min(len(dataloader), master_config['grpo']['max_num_steps'])} {'=' * 25}",
+    flush=True,
)
maybe_gpu_profile_step(policy, step + 1)
if policy != policy_generation:
@@ -555,7 +567,7 @@ def grpo_train(

with timer.time("total_step_time"):
# Prepare batch
-print("▶ Preparing batch...")
+print("▶ Preparing batch...", flush=True)
with timer.time("data_processing"):
# Repeat batch items
repeated_batch: BatchedDataDict[DatumSpec] = batch.repeat_interleave(
@@ -569,7 +581,10 @@ def grpo_train(
input_ids = batched_flat["token_ids"]

# Generate responses - this updates the LLMMessageLogType in repeated_batch
-print(f"▶ Generating responses for batch of size {repeated_batch.size}...")
+print(
+    f"▶ Generating responses for batch of size {repeated_batch.size}...",
+    flush=True,
+)
with timer.time("prepare_for_generation"):
if NEED_REFIT and POLICY_GENERATION_STALE:
refit_policy_generation(
@@ -611,12 +626,12 @@ def grpo_train(
policy_generation.finish_generation()

# Calculate rewards & advantages
-print("▶ Processing rewards...")
+print("▶ Processing rewards...", flush=True)
with timer.time("reward_calculation"):
# Extract rewards from final_batch
rewards = repeated_batch["total_reward"]

-print("▶ Computing advantages...")
+print("▶ Computing advantages...", flush=True)
baseline, std = calculate_baseline_and_std_per_prompt(
input_ids,
rewards,
@@ -678,11 +693,11 @@ def grpo_train(
train_data.update(flat_messages.get_multimodal_dict(as_tensors=False))
train_data.to("cpu")

-print("▶ Preparing for logprob inference...")
+print("▶ Preparing for logprob inference...", flush=True)
with timer.time("logprob_inference_prep"):
policy.prepare_for_lp_inference()

-print("▶ Computing logprobs...")
+print("▶ Computing logprobs...", flush=True)
with timer.time("policy_and_reference_logprobs"):
fprop_logprobs = policy.get_logprobs(train_data)["logprobs"]
reference_logprobs = policy.get_reference_policy_logprobs(train_data)[
@@ -691,12 +706,12 @@ def grpo_train(
train_data["prev_logprobs"] = fprop_logprobs
train_data["reference_policy_logprobs"] = reference_logprobs

-print("▶ Preparing for training...")
+print("▶ Preparing for training...", flush=True)
with timer.time("training_prep"):
policy.prepare_for_training() # set model train and reload optim to GPU
POLICY_GENERATION_STALE = True

-print("▶ Training policy...")
+print("▶ Training policy...", flush=True)
with timer.time("policy_training"):
train_results = policy.train(train_data, loss_fn)

@@ -763,7 +778,7 @@ def grpo_train(
master_config["checkpointing"]["metric_name"] = None

with timer.time("checkpointing"):
-print(f"Saving checkpoint for step {step + 1}...")
+print(f"Saving checkpoint for step {step + 1}...", flush=True)
checkpoint_path = checkpointer.init_tmp_checkpoint(
step + 1, grpo_save_state, master_config
)
@@ -834,24 +849,27 @@ def grpo_train(
print(f" • Loss: {metrics['loss']:.4f}")
print(f" • Avg Reward: {np.mean(rewards.numpy()):.4f}")
print(
-    f" • Mean Generation Length: {rollout_metrics['mean_gen_tokens_per_sample']:.4f}"
+    f" • Mean Generation Length: {rollout_metrics['mean_gen_tokens_per_sample']:.4f}",
+    flush=True,
)
if "total_flops" in train_results:
total_tflops = (
train_results["total_flops"] / timing_metrics["policy_training"] / 1e12
)
num_ranks = train_results["num_ranks"]
print(
-    f" • Training FLOPS: {total_tflops:.2f} TFLOPS ({total_tflops / num_ranks:.2f} TFLOPS per rank)"
+    f" • Training FLOPS: {total_tflops:.2f} TFLOPS ({total_tflops / num_ranks:.2f} TFLOPS per rank)",
+    flush=True,
)
if "theoretical_tflops" in train_results:
theoretical_tflops = train_results["theoretical_tflops"]
print(
-    f" • Training Model Floating Point Utilization: {100 * total_tflops / theoretical_tflops:.2f}%"
+    f" • Training Model Floating Point Utilization: {100 * total_tflops / theoretical_tflops:.2f}%",
+    flush=True,
)
metrics["train_fp_utilization"] = total_tflops / theoretical_tflops

-print("\n⏱️ Timing:")
+print("\n⏱️ Timing:", flush=True)
# Display total time first, separately
total_time = timing_metrics.get("total_step_time", 0)

@@ -867,15 +885,15 @@ def grpo_train(
}
)

-print(f" • Total step time: {total_time:.2f}s")
+print(f" • Total step time: {total_time:.2f}s", flush=True)

# Display all other timing metrics
for k, v in sorted(
timing_metrics.items(), key=lambda item: item[1], reverse=True
):
if k != "total_step_time":
percent = (v / total_time * 100) if total_time > 0 else 0
-print(f" • {k}: {v:.2f}s ({percent:.1f}%)")
+print(f" • {k}: {v:.2f}s ({percent:.1f}%)", flush=True)

logger.log_metrics(metrics, step + 1, prefix="train")
logger.log_metrics(timing_metrics, step + 1, prefix="timing/train")
Expand All @@ -896,12 +914,12 @@ def validate(
) -> tuple[dict[str, Any], dict[str, Any]]:
"""Run validation on the validation dataset."""
if val_dataloader is None:
-print(" ⚠️ No validation dataloader provided, skipping validation")
+print(" ⚠️ No validation dataloader provided, skipping validation", flush=True)
return {}, {}

timer = Timer()
with timer.time("total_validation_time"):
-print(f"▶ Starting validation at step {step}...")
+print(f"▶ Starting validation at step {step}...", flush=True)

total_rewards = []
total_lengths = []
@@ -974,7 +992,7 @@ def validate(
)
except Exception as e:
print(f"\n ⚠️ Error displaying message samples: {str(e)}")
-print(" ⚠️ Continuing validation without displaying samples...")
+print(" ⚠️ Continuing validation without displaying samples...", flush=True)

# Get timing metrics
timing_metrics = timer.get_timing_metrics(reduction_op="sum")
@@ -984,12 +1002,12 @@ def validate(
print("\n📊 Validation Results:")
print(f" • Accuracy: {accuracy:.4f}")
print(f" • Average response length: {avg_length:.1f} tokens")
-print(f" • Samples processed: {len(total_rewards)}")
+print(f" • Samples processed: {len(total_rewards)}", flush=True)

# Print timing information
print("\n ⏱️ Validation Timing:")
validation_time = timing_metrics.get("total_validation_time", 0)
-print(f" • Total validation time: {validation_time:.2f}s")
+print(f" • Total validation time: {validation_time:.2f}s", flush=True)

# Make sure to reset the timer after validation
timer.reset()
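A possible consolidation, noted here as a hypothetical alternative rather than anything this PR does: unbuffer stdout once at process startup instead of adding flush=True at every call site.

import sys

# Hypothetical alternative (not part of this diff): force line buffering
# process-wide so every completed print line is flushed automatically.
sys.stdout.reconfigure(line_buffering=True)

The same effect is available from the launcher with python -u or PYTHONUNBUFFERED=1, with no code changes.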