diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index 190f3c2921..6e6869bd21 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -32,7 +32,11 @@ ClippedPGLossDataDict, ClippedPGLossFn, ) -from nemo_rl.algorithms.utils import calculate_baseline_and_std_per_prompt, set_seed +from nemo_rl.algorithms.utils import ( + calculate_baseline_and_std_per_prompt, + print_performance_metrics, + set_seed, +) from nemo_rl.data import DataConfig from nemo_rl.data.collate_fn import rl_collate_fn from nemo_rl.data.datasets import AllTaskProcessedDataset @@ -954,7 +958,6 @@ def grpo_train( total_steps + 1, name="train/token_mult_prob_error_plot_sample", ) - print("\n📊 Training Results:") print(f" • Loss: {metrics['loss']:.4f}") @@ -963,40 +966,19 @@ def grpo_train( f" • Mean Generation Length: {rollout_metrics['mean_gen_tokens_per_sample']:.4f}", flush=True, ) - if "total_flops" in train_results: - total_tflops = ( - train_results["total_flops"] - / timing_metrics["policy_training"] - / 1e12 - ) - num_ranks = train_results["num_ranks"] - print( - f" • Training FLOPS: {total_tflops:.2f} TFLOPS ({total_tflops / num_ranks:.2f} TFLOPS per rank)", - flush=True, - ) - if "theoretical_tflops" in train_results: - theoretical_tflops = train_results["theoretical_tflops"] - print( - f" • Training Model Floating Point Utilization: {100 * total_tflops / theoretical_tflops:.2f}%", - flush=True, - ) - metrics["train_fp_utilization"] = total_tflops / theoretical_tflops print("\n⏱️ Timing:", flush=True) # Display total time first, separately total_time = timing_metrics.get("total_step_time", 0) + number_of_samples_per_step = ( + master_config["grpo"]["num_prompts_per_step"] + * master_config["grpo"]["num_generations_per_prompt"] + ) total_num_gpus = ( master_config["cluster"]["num_nodes"] * master_config["cluster"]["gpus_per_node"] ) - metrics.update( - { - "tokens_per_sec_per_gpu": metrics["total_num_tokens"] - / total_time - / total_num_gpus - } - ) 
print(f" • Total step time: {total_time:.2f}s", flush=True) @@ -1008,7 +990,14 @@ def grpo_train( percent = (v / total_time * 100) if total_time > 0 else 0 print(f" • {k}: {v:.2f}s ({percent:.1f}%)", flush=True) + performance_metrics = print_performance_metrics( + train_results, metrics, timing_metrics, master_config + ) + logger.log_metrics(metrics, total_steps + 1, prefix="train") + logger.log_metrics( + performance_metrics, total_steps + 1, prefix="performance" + ) logger.log_metrics(timing_metrics, total_steps + 1, prefix="timing/train") timer.reset() @@ -1446,6 +1435,7 @@ def async_grpo_train( for t in trajectories: for k, v in t["rollout_metrics"].items(): rollout_metrics.setdefault(k, []).append(v) + # TODO: this simple averaging might cause misleading information for such data as max_gen_tokens, etc. rollout_metrics = { k: (sum(v) / len(v) if isinstance(v[0], (int, float)) else v) for k, v in rollout_metrics.items() @@ -1711,6 +1701,8 @@ def async_grpo_train( "loss": train_results["loss"].numpy(), "reward": rewards.numpy(), "grad_norm": train_results["grad_norm"].numpy(), + "mean_prompt_length": repeated_batch["length"].numpy(), + "total_num_tokens": input_lengths.numpy(), } metrics.update(train_results["all_mb_metrics"]) for k, v in metrics.items(): @@ -1751,6 +1743,14 @@ def async_grpo_train( percent = (v / total_time * 100) if total_time > 0 else 0 print(f" • {k}: {v:.2f}s ({percent:.1f}%)") + performance_metrics = print_performance_metrics( + train_results, metrics, timing_metrics, master_config + ) + + if "per_worker_token_counts" in metrics: + del metrics["per_worker_token_counts"] + + logger.log_metrics(performance_metrics, step + 1, prefix="performance") logger.log_metrics(metrics, step + 1, prefix="train") logger.log_metrics(timing_metrics, step + 1, prefix="timing/train") diff --git a/nemo_rl/algorithms/utils.py b/nemo_rl/algorithms/utils.py index fba7aafb26..e15d731efb 100644 --- a/nemo_rl/algorithms/utils.py +++ b/nemo_rl/algorithms/utils.py @@ 
def print_performance_metrics(
    train_results: dict[str, float],
    metrics: dict[str, float],
    timing_metrics: dict[str, float],
    master_config: dict,
) -> dict[str, float]:
    """Print and return per-step performance metrics for GRPO.

    Covers generation-worker token-load balance, end-to-end and per-group
    throughputs, optional training-worker idle time (async, non-colocated
    runs), and optional FLOPS / floating-point-utilization figures.

    Args:
        train_results: Training outputs for the step. May contain
            ``total_flops``, ``num_ranks`` and ``theoretical_tflops``; the
            FLOPS section is skipped when ``total_flops`` is absent.
        metrics: Step metrics. Must contain ``total_num_tokens``; may contain
            ``per_worker_token_counts`` as either a dict
            (worker_idx -> token_count) or a list of such dicts (one per
            trajectory, async path).
        timing_metrics: Step timings. Must contain
            ``policy_and_reference_logprobs``, ``policy_training`` and
            ``total_step_time``. The presence of ``generation`` marks sync
            GRPO; otherwise ``exposed_generation`` (async GRPO) is used.
        master_config: Full master config; cluster sizes, generation
            resources and GRPO batch shape are read from it.

    Returns:
        A flat dict of computed performance metrics, suitable for
        ``logger.log_metrics(..., prefix="performance")``.
    """

    # =====================================================
    # Generate Token Imbalance Visualization
    # =====================================================
    def visualize_per_worker_load(per_worker_token_counts: dict[int, int]) -> float:
        # Sort by worker index so the bars are printed in a stable order.
        per_worker_token_counts_list = [
            v for k, v in sorted(per_worker_token_counts.items())
        ]
        # Normalize each worker's token count by the busiest worker.
        per_worker_load_ratio = [
            v / max(per_worker_token_counts_list) for v in per_worker_token_counts_list
        ]
        max_rows_to_print = 100  # cap console spam for very large worker pools
        print(" • Visualizing Token Imbalance per Generation Worker:")
        for i in range(min(len(per_worker_token_counts_list), max_rows_to_print)):
            print(
                f"    - Generated Tokens from Worker {i:3.0f}:"
                f"{'■' * int(per_worker_load_ratio[i] * 10)}"
                f"{'□' * (10 - int(per_worker_load_ratio[i] * 10))}"
                f" Count: {per_worker_token_counts_list[i] / 1000:.1f}K"
            )
        # 1 - mean(load ratio): 0.0 means perfectly balanced, larger means
        # more of the pool idles while the busiest worker finishes.
        estimated_idle_ratio = 1 - sum(per_worker_load_ratio) / len(
            per_worker_load_ratio
        )
        print(f" • Average Token Imbalance: {100 * estimated_idle_ratio:.2f}%")
        return estimated_idle_ratio

    print("\n🔍 Performance Metrics:")
    performance_metrics = {}

    if "per_worker_token_counts" in metrics:
        # Async GRPO reports one dict per trajectory; merge them. Sync GRPO
        # reports a single dict. Anything else is ignored.
        if isinstance(metrics["per_worker_token_counts"], list):
            per_worker_token_counts = {}
            for trajectory_metrics in metrics["per_worker_token_counts"]:
                for worker_idx, token_count in trajectory_metrics.items():
                    per_worker_token_counts[worker_idx] = (
                        per_worker_token_counts.get(worker_idx, 0) + token_count
                    )
        elif isinstance(metrics["per_worker_token_counts"], dict):
            per_worker_token_counts = metrics["per_worker_token_counts"]
        else:
            per_worker_token_counts = None

        if per_worker_token_counts is not None:
            average_token_imbalance = visualize_per_worker_load(per_worker_token_counts)
            performance_metrics["average_token_imbalance"] = average_token_imbalance

    # =====================================================
    # Throughputs
    # =====================================================

    policy_and_reference_logprobs_time = timing_metrics["policy_and_reference_logprobs"]
    policy_training_time = timing_metrics["policy_training"]
    total_time = timing_metrics["total_step_time"]
    # Sync GRPO logs refit under "weight_sync"; async GRPO logs it under
    # "prepare_for_generation/total".
    refit_time = (
        timing_metrics["weight_sync"]
        if "weight_sync" in timing_metrics
        else timing_metrics["prepare_for_generation/total"]
    )
    if "generation" in timing_metrics:  # Sync GRPO
        generation_time = timing_metrics["generation"]
    else:  # Async GRPO
        # Generation overlaps with training; only "exposed_generation" is the
        # part not hidden behind training/logprobs. Total wall-clock time the
        # generation workers were busy is therefore:
        #   exposed generation + logprobs + training
        generation_time = (
            timing_metrics["exposed_generation"]
            + timing_metrics["policy_and_reference_logprobs"]
            + timing_metrics["policy_training"]
        )

    num_nodes = master_config["cluster"]["num_nodes"]
    gpus_per_node = master_config["cluster"]["gpus_per_node"]
    total_num_gpus = num_nodes * gpus_per_node
    colocated_inference = master_config["policy"]["generation"]["colocated"]["enabled"]

    # Idle Time from Training Worker (Async GRPO only)
    if (
        "async_grpo" in master_config and master_config["async_grpo"]["enabled"]
    ) and not colocated_inference:
        exposed_generation_time = timing_metrics["exposed_generation"]
        # The trainer sits idle while generation is exposed (not overlapped).
        # Treat exposures at or below a 0.1s noise threshold as zero;
        # otherwise report exposed time as a fraction of the full
        # training-worker cycle.
        # BUGFIX: the previous condition was inverted
        # (`0 if exposed_generation_time > 0.1 else ...`), which zeroed the
        # metric exactly when there WAS meaningful idle time.
        if exposed_generation_time <= 0.1:
            training_worker_idle_time_ratio = 0
        else:
            training_worker_idle_time_ratio = exposed_generation_time / (
                policy_training_time
                + policy_and_reference_logprobs_time
                + exposed_generation_time
                + refit_time
            )
        print(
            f" • Training Worker Idle Time Ratio: {100 * training_worker_idle_time_ratio:.2f}%"
        )
        performance_metrics["training_worker_idle_time_ratio"] = (
            training_worker_idle_time_ratio
        )

    number_of_samples_per_step = (
        master_config["grpo"]["num_prompts_per_step"]
        * master_config["grpo"]["num_generations_per_prompt"]
    )

    if colocated_inference:
        # Colocated: every GPU does both training and generation.
        training_num_gpus = total_num_gpus
        generation_num_gpus = total_num_gpus
    else:
        generation_num_nodes = (
            master_config["policy"]["generation"]["colocated"]["resources"]["num_nodes"]
            or 1
        )
        generation_num_gpus = (
            master_config["policy"]["generation"]["colocated"]["resources"][
                "gpus_per_node"
            ]
            * generation_num_nodes
        )
        training_num_gpus = total_num_gpus - generation_num_gpus

    e2e_samples_per_sec_per_gpu = (
        number_of_samples_per_step / total_time / total_num_gpus
    )
    e2e_tokens_per_sec_per_gpu = (
        metrics["total_num_tokens"] / total_time / total_num_gpus
    )
    policy_training_tokens_per_sec_per_gpu = (
        metrics["total_num_tokens"] / policy_training_time / training_num_gpus
    )
    policy_and_reference_logprobs_tokens_per_sec_per_gpu = (
        metrics["total_num_tokens"]
        / policy_and_reference_logprobs_time
        / training_num_gpus
    )
    training_worker_group_tokens_per_sec_per_gpu = (
        metrics["total_num_tokens"]
        / (policy_training_time + policy_and_reference_logprobs_time)
        / training_num_gpus
    )
    generation_tokens_per_sec_per_gpu = (
        metrics["total_num_tokens"] / generation_time / generation_num_gpus
    )

    print(" • Throughputs (per GPU):")
    print(f"    - E2E (Samples/sec/gpu): {e2e_samples_per_sec_per_gpu:.2f}")
    print(f"    - E2E (Tokens/sec/gpu): {e2e_tokens_per_sec_per_gpu:.2f}")
    print(
        f"    - Policy Training (Tokens/sec/gpu): {policy_training_tokens_per_sec_per_gpu:.2f}"
    )
    print(
        f"    - Policy and Reference Logprobs (Tokens/sec/gpu): {policy_and_reference_logprobs_tokens_per_sec_per_gpu:.2f}"
    )
    print(
        f"    - Training Worker Group (Tokens/sec/gpu): {training_worker_group_tokens_per_sec_per_gpu:.2f}"
    )
    print(
        f"    - Generation Worker Group (Tokens/sec/gpu): {generation_tokens_per_sec_per_gpu:.2f}"
    )

    print(" • Throughputs (per Group):")
    print(
        f"    - E2E (Samples/sec): {(e2e_samples_per_sec_per_gpu * total_num_gpus):.2f}"
    )
    print(
        f"    - E2E (Tokens/sec): {(e2e_tokens_per_sec_per_gpu * total_num_gpus):.2f}"
    )
    print(
        f"    - Training Worker Group (Tokens/sec): {(training_worker_group_tokens_per_sec_per_gpu * training_num_gpus):.2f}"
    )
    print(
        f"    - Generation Worker Group (Tokens/sec): {(generation_tokens_per_sec_per_gpu * generation_num_gpus):.2f}"
    )

    # =====================================================
    # FLOPS
    # =====================================================

    if "total_flops" in train_results:
        # TFLOPS achieved during the training phase only.
        total_tflops = (
            train_results["total_flops"] / timing_metrics["policy_training"] / 1e12
        )
        num_ranks = train_results["num_ranks"]
        print(
            f" • Training FLOPS: {total_tflops:.2f} TFLOPS ({total_tflops / num_ranks:.2f} TFLOPS per rank)",
            flush=True,
        )
        performance_metrics["train_flops_per_gpu"] = total_tflops / num_ranks
        if "theoretical_tflops" in train_results:
            theoretical_tflops = train_results["theoretical_tflops"]
            print(
                f" • Training Model Floating Point Utilization: {100 * total_tflops / theoretical_tflops:.2f}%",
                flush=True,
            )
            performance_metrics["train_fp_utilization"] = (
                total_tflops / theoretical_tflops
            )

    # =====================================================
    # Logging
    # =====================================================

    performance_metrics.update(
        {
            "samples_per_sec": e2e_samples_per_sec_per_gpu * total_num_gpus,
            "tokens_per_sec": e2e_tokens_per_sec_per_gpu * total_num_gpus,
            "samples_per_sec_per_gpu": e2e_samples_per_sec_per_gpu,
            "tokens_per_sec_per_gpu": e2e_tokens_per_sec_per_gpu,
            "policy_training_tokens_per_sec_per_gpu": policy_training_tokens_per_sec_per_gpu,
            "policy_and_reference_logprobs_tokens_per_sec_per_gpu": policy_and_reference_logprobs_tokens_per_sec_per_gpu,
            "training_worker_group_tokens_per_sec_per_gpu": training_worker_group_tokens_per_sec_per_gpu,
            "generation_tokens_per_sec_per_gpu": generation_tokens_per_sec_per_gpu,
            "training_worker_group_tokens_per_sec": training_worker_group_tokens_per_sec_per_gpu
            * training_num_gpus,
            "generation_tokens_per_sec": generation_tokens_per_sec_per_gpu
            * generation_num_gpus,
        }
    )

    return performance_metrics
+679,12 @@ async def run_sample_multi_turn_rollout( assistant_token_count += gen_token_count token_count += gen_token_count turn_gen_tokens.append(gen_token_count) + # Per-worker load accounting + if "gen_leader_worker_idx" in gen_metrics: + worker_idx = int(gen_metrics["gen_leader_worker_idx"]) + per_worker_token_counts[worker_idx] = ( + per_worker_token_counts.get(worker_idx, 0) + gen_token_count + ) except Exception as e: print(f"Error generating response for sample {sample_idx}: {e}") @@ -743,6 +764,8 @@ async def run_sample_multi_turn_rollout( "max_turns_reached": max_turns_reached, "total_reward": total_reward, "turn_gen_tokens": turn_gen_tokens, + # Pass-through per-worker per-turn accounting for aggregation at batch level + "per_worker_token_counts": per_worker_token_counts, } return final_sample_state, sample_metrics @@ -879,6 +902,9 @@ async def run_single_sample_with_error_handling(i, sample_state): m["assistant_tokens"] for m in all_sample_metrics ) / batch_size, + "max_gen_tokens_per_sample": max( + m["assistant_tokens"] for m in all_sample_metrics + ), "mean_env_tokens_per_sample": sum( m["env_tokens"] for m in all_sample_metrics ) @@ -890,6 +916,14 @@ async def run_single_sample_with_error_handling(i, sample_state): "min_total_reward": min(m["total_reward"] for m in all_sample_metrics), } + # Calculate per-worker token counts + if "per_worker_token_counts" in all_sample_metrics[0]: + per_worker_token_counts = {} + for m in all_sample_metrics: + for k, v in m["per_worker_token_counts"].items(): + per_worker_token_counts[k] = per_worker_token_counts.get(k, 0) + v + rollout_metrics["per_worker_token_counts"] = per_worker_token_counts + return final_batch, rollout_metrics return asyncio.run(_async_rollout_implementation()) diff --git a/nemo_rl/models/generation/vllm/vllm_generation.py b/nemo_rl/models/generation/vllm/vllm_generation.py index 7fc44ab450..f67cbea41a 100644 --- a/nemo_rl/models/generation/vllm/vllm_generation.py +++ 
b/nemo_rl/models/generation/vllm/vllm_generation.py @@ -554,6 +554,12 @@ async def consume_worker_generator(worker_idx, worker_gen): try: async for sample_result_ref in worker_gen: sample_result = await sample_result_ref + # sample_result is a tuple: (original_idx, BatchedDataDict) + # Tag the result with worker index for downstream attribution + original_idx, result_batch = sample_result + # Use a length-one list so BatchedDataDict.from_batches can merge without shape errors + result_batch["gen_leader_worker_idx"] = [int(worker_idx)] + sample_result = (original_idx, result_batch) await result_queue.put(("sample", sample_result)) except Exception as e: # Log the error before putting it in the queue for better debugging diff --git a/tests/unit/algorithms/test_utils.py b/tests/unit/algorithms/test_utils.py index 7562e541ed..ce049a19db 100755 --- a/tests/unit/algorithms/test_utils.py +++ b/tests/unit/algorithms/test_utils.py @@ -12,12 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# Performance Metrics Tests


def _base_master_config(colocated: bool):
    """Minimal master config: 2x8 GPUs, 8 prompts x 10 generations per step."""
    generation_cfg = {
        "colocated": {
            "enabled": colocated,
            "resources": {"num_nodes": 1, "gpus_per_node": 8},
        }
    }
    return {
        "cluster": {"num_nodes": 2, "gpus_per_node": 8},
        "policy": {"generation": generation_cfg},
        "grpo": {"num_prompts_per_step": 8, "num_generations_per_prompt": 10},
    }


def test_sync_colocated_throughput_flops_and_imbalance(capsys):
    cfg = _base_master_config(colocated=True)

    timings = {
        "policy_and_reference_logprobs": 2.0,
        "policy_training": 4.0,
        "total_step_time": 10.0,
        "generation": 5.0,
        "weight_sync": 1.0,
    }

    # total_num_gpus = 2 * 8 = 16; samples_per_step = 8 * 10 = 80
    step_metrics = {
        "total_num_tokens": 8000.0,
        "per_worker_token_counts": {0: 1000, 1: 2000, 2: 3000, 3: 4000},
    }

    # total_tflops = total_flops / policy_training / 1e12 = 1e15 / 4 / 1e12 = 250
    # -> 31.25 TFLOPS per rank across 8 ranks
    results = {
        "total_flops": 1.0e15,
        "num_ranks": 8,
        "theoretical_tflops": 500.0,
    }

    perf = print_performance_metrics(results, step_metrics, timings, cfg)

    expected = {
        "samples_per_sec_per_gpu": 0.5,
        "tokens_per_sec_per_gpu": 50.0,
        "policy_training_tokens_per_sec_per_gpu": 125.0,
        "policy_and_reference_logprobs_tokens_per_sec_per_gpu": 250.0,
        "training_worker_group_tokens_per_sec_per_gpu": 8000.0 / 6.0 / 16.0,
        "generation_tokens_per_sec_per_gpu": 8000.0 / 5.0 / 16.0,
        "samples_per_sec": 8.0,
        "tokens_per_sec": 800.0,
        "training_worker_group_tokens_per_sec": 8000.0 / 6.0,
        # Imbalance from load ratios [0.25, 0.5, 0.75, 1.0]: 1 - 0.625
        "average_token_imbalance": 0.375,
    }
    for key, value in expected.items():
        assert math.isclose(perf[key], value, rel_tol=1e-6), key

    console = capsys.readouterr().out
    for snippet in (
        "Performance Metrics",
        "Throughputs (per GPU)",
        "Average Token Imbalance",
        "Training FLOPS",
        "Floating Point Utilization",
    ):
        assert snippet in console


def test_async_non_colocated_idle_ratio_and_generation_time(capsys):
    cfg = _base_master_config(colocated=False)
    cfg["async_grpo"] = {"enabled": True}

    timings = {
        "policy_and_reference_logprobs": 2.0,
        "policy_training": 4.0,
        "total_step_time": 10.0,
        "exposed_generation": 2.0,
        "prepare_for_generation/total": 1.0,
    }

    # total_num_gpus = 16, training_num_gpus = 8, generation_num_gpus = 8
    step_metrics = {
        "total_num_tokens": 6050.0,
        "per_worker_token_counts": [{0: 3000}, {1: 3050}],
    }

    perf = print_performance_metrics({}, step_metrics, timings, cfg)

    expected = {
        "samples_per_sec_per_gpu": 0.5,
        "tokens_per_sec_per_gpu": 6050.0 / 10.0 / 16.0,
        "policy_training_tokens_per_sec_per_gpu": 6050.0 / 4.0 / 8.0,
        "policy_and_reference_logprobs_tokens_per_sec_per_gpu": 6050.0 / 2.0 / 8.0,
        "training_worker_group_tokens_per_sec_per_gpu": 6050.0 / (4.0 + 2.0) / 8.0,
        # generation_time = 2 + 2 + 4 = 8.0 -> per-gpu = 6050 / 8.0 / 8.0
        "generation_tokens_per_sec_per_gpu": 6050.0 / 8.0 / 8.0,
        # Aggregated worker counts {0: 3000, 1: 3050} -> imbalance 50/3050/2
        "average_token_imbalance": ((3050 - 3000) / 3050) / 2,
    }
    for key, value in expected.items():
        assert math.isclose(perf[key], value, rel_tol=1e-6), key


def test_minimal_inputs_no_counts_no_flops(capsys):
    cfg = _base_master_config(colocated=False)

    timings = {
        "policy_and_reference_logprobs": 1.0,
        "policy_training": 3.0,
        "total_step_time": 8.0,
        "exposed_generation": 0.2,
        "prepare_for_generation/total": 0.5,
    }

    # no per_worker_token_counts, no FLOPS info in train_results
    step_metrics = {"total_num_tokens": 1600.0}

    perf = print_performance_metrics({}, step_metrics, timings, cfg)

    # Core throughput metrics must exist even with minimal inputs
    for key in (
        "samples_per_sec",
        "tokens_per_sec",
        "samples_per_sec_per_gpu",
        "tokens_per_sec_per_gpu",
    ):
        assert key in perf

    assert "Throughputs (per GPU)" in capsys.readouterr().out