Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 27 additions & 27 deletions nemo_rl/algorithms/grpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,11 @@
ClippedPGLossDataDict,
ClippedPGLossFn,
)
from nemo_rl.algorithms.utils import calculate_baseline_and_std_per_prompt, set_seed
from nemo_rl.algorithms.utils import (
calculate_baseline_and_std_per_prompt,
print_performance_metrics,
set_seed,
)
from nemo_rl.data import DataConfig
from nemo_rl.data.collate_fn import rl_collate_fn
from nemo_rl.data.datasets import AllTaskProcessedDataset
Expand Down Expand Up @@ -954,7 +958,6 @@ def grpo_train(
total_steps + 1,
name="train/token_mult_prob_error_plot_sample",
)

print("\n📊 Training Results:")

print(f" • Loss: {metrics['loss']:.4f}")
Expand All @@ -963,40 +966,19 @@ def grpo_train(
f" • Mean Generation Length: {rollout_metrics['mean_gen_tokens_per_sample']:.4f}",
flush=True,
)
if "total_flops" in train_results:
total_tflops = (
train_results["total_flops"]
/ timing_metrics["policy_training"]
/ 1e12
)
num_ranks = train_results["num_ranks"]
print(
f" • Training FLOPS: {total_tflops:.2f} TFLOPS ({total_tflops / num_ranks:.2f} TFLOPS per rank)",
flush=True,
)
if "theoretical_tflops" in train_results:
theoretical_tflops = train_results["theoretical_tflops"]
print(
f" • Training Model Floating Point Utilization: {100 * total_tflops / theoretical_tflops:.2f}%",
flush=True,
)
metrics["train_fp_utilization"] = total_tflops / theoretical_tflops

print("\n⏱️ Timing:", flush=True)
# Display total time first, separately
total_time = timing_metrics.get("total_step_time", 0)

number_of_samples_per_step = (
master_config["grpo"]["num_prompts_per_step"]
* master_config["grpo"]["num_generations_per_prompt"]
)
total_num_gpus = (
master_config["cluster"]["num_nodes"]
* master_config["cluster"]["gpus_per_node"]
)
metrics.update(
{
"tokens_per_sec_per_gpu": metrics["total_num_tokens"]
/ total_time
/ total_num_gpus
}
)

print(f" • Total step time: {total_time:.2f}s", flush=True)

Expand All @@ -1008,7 +990,14 @@ def grpo_train(
percent = (v / total_time * 100) if total_time > 0 else 0
print(f" • {k}: {v:.2f}s ({percent:.1f}%)", flush=True)

performance_metrics = print_performance_metrics(
train_results, metrics, timing_metrics, master_config
)

logger.log_metrics(metrics, total_steps + 1, prefix="train")
logger.log_metrics(
performance_metrics, total_steps + 1, prefix="performance"
)
logger.log_metrics(timing_metrics, total_steps + 1, prefix="timing/train")

timer.reset()
Expand Down Expand Up @@ -1446,6 +1435,7 @@ def async_grpo_train(
for t in trajectories:
for k, v in t["rollout_metrics"].items():
rollout_metrics.setdefault(k, []).append(v)
# TODO: this simple averaging might cause misleading information for such data as max_gen_tokens, etc.
rollout_metrics = {
k: (sum(v) / len(v) if isinstance(v[0], (int, float)) else v)
for k, v in rollout_metrics.items()
Expand Down Expand Up @@ -1711,6 +1701,8 @@ def async_grpo_train(
"loss": train_results["loss"].numpy(),
"reward": rewards.numpy(),
"grad_norm": train_results["grad_norm"].numpy(),
"mean_prompt_length": repeated_batch["length"].numpy(),
"total_num_tokens": input_lengths.numpy(),
}
metrics.update(train_results["all_mb_metrics"])
for k, v in metrics.items():
Expand Down Expand Up @@ -1751,6 +1743,14 @@ def async_grpo_train(
percent = (v / total_time * 100) if total_time > 0 else 0
print(f" • {k}: {v:.2f}s ({percent:.1f}%)")

performance_metrics = print_performance_metrics(
train_results, metrics, timing_metrics, master_config
)

if "per_worker_token_counts" in metrics:
del metrics["per_worker_token_counts"]

logger.log_metrics(performance_metrics, step + 1, prefix="performance")
logger.log_metrics(metrics, step + 1, prefix="train")
logger.log_metrics(timing_metrics, step + 1, prefix="timing/train")

Expand Down
230 changes: 230 additions & 0 deletions nemo_rl/algorithms/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,3 +345,233 @@ def maybe_pad_last_batch(batch: dict, dp_size: int, mbs: int) -> dict:
]
)
return batch


def print_performance_metrics(
train_results: dict[str, float],
metrics: dict[str, float],
timing_metrics: dict[str, float],
master_config: dict,
) -> dict[str, float]:
"""Print performance metrics for GRPO."""

# =====================================================
# Generate Token Imbalance Visualization
# =====================================================
def visualize_per_worker_load(per_worker_token_counts: dict[int, int]) -> float:
per_worker_token_counts_list = [
v for k, v in sorted(per_worker_token_counts.items())
]
per_worker_load_ratio = [
v / max(per_worker_token_counts_list) for v in per_worker_token_counts_list
]
max_rows_to_print = 100
print(" • Visualizing Token Imbalance per Generation Worker:")
for i in range(min(len(per_worker_token_counts_list), max_rows_to_print)):
print(
f" - Generated Tokens from Worker {i:3.0f}:"
f"{'■' * int(per_worker_load_ratio[i] * 10)}"
f"{'□' * (10 - int(per_worker_load_ratio[i] * 10))}"
f" Count: {per_worker_token_counts_list[i] / 1000:.1f}K"
)
estimated_idle_ratio = 1 - sum(per_worker_load_ratio) / len(
per_worker_load_ratio
)
print(f" • Average Token Imbalance: {100 * estimated_idle_ratio:.2f}%")
return estimated_idle_ratio

print("\n🔍 Performance Metrics:")
performance_metrics = {}

if "per_worker_token_counts" in metrics:
# Can be a list of each trajectory
if isinstance(metrics["per_worker_token_counts"], list):
per_worker_token_counts = {}
for trajectory_metrics in metrics["per_worker_token_counts"]:
for worker_idx, token_count in trajectory_metrics.items():
per_worker_token_counts[worker_idx] = (
per_worker_token_counts.get(worker_idx, 0) + token_count
)
elif isinstance(metrics["per_worker_token_counts"], dict):
per_worker_token_counts = metrics["per_worker_token_counts"]
else:
per_worker_token_counts = None

if per_worker_token_counts is not None:
average_token_imbalance = visualize_per_worker_load(per_worker_token_counts)
performance_metrics["average_token_imbalance"] = average_token_imbalance

# =====================================================
# Throughputs
# =====================================================

policy_and_reference_logprobs_time = timing_metrics["policy_and_reference_logprobs"]
policy_training_time = timing_metrics["policy_training"]
total_time = timing_metrics["total_step_time"]
refit_time = (
timing_metrics["weight_sync"]
if "weight_sync" in timing_metrics
else timing_metrics["prepare_for_generation/total"]
)
if "generation" in timing_metrics: # Sync GRPO
generation_time = timing_metrics["generation"]
else: # Async GRPO
# If the training time is greater than the generation time, we include the idle time caused by training as part of the generation time.
# if training time > generation time, generation time = training time
# if training time < generation time, generation time = training time + exposed generation time
generation_time = (
timing_metrics["exposed_generation"]
+ timing_metrics["policy_and_reference_logprobs"]
+ timing_metrics["policy_training"]
)

num_nodes = master_config["cluster"]["num_nodes"]
gpus_per_node = master_config["cluster"]["gpus_per_node"]
total_num_gpus = num_nodes * gpus_per_node
colocated_inference = master_config["policy"]["generation"]["colocated"]["enabled"]

# Idle Time from Training Worker (Async GRPO only)
if (
"async_grpo" in master_config and master_config["async_grpo"]["enabled"]
) and not colocated_inference:
# async grpo
exposed_generation_time = timing_metrics["exposed_generation"]
training_worker_idle_time_ratio = (
0
if exposed_generation_time > 0.1
else exposed_generation_time
/ (
policy_training_time
+ policy_and_reference_logprobs_time
+ exposed_generation_time
+ refit_time
)
)
print(
f" • Training Worker Idle Time Ratio: {100 * training_worker_idle_time_ratio:.2f}%"
)
performance_metrics["training_worker_idle_time_ratio"] = (
training_worker_idle_time_ratio
)

number_of_samples_per_step = (
master_config["grpo"]["num_prompts_per_step"]
* master_config["grpo"]["num_generations_per_prompt"]
)

if colocated_inference:
training_num_gpus = total_num_gpus
generation_num_gpus = total_num_gpus
else:
generation_num_nodes = (
master_config["policy"]["generation"]["colocated"]["resources"]["num_nodes"]
or 1
)
generation_num_gpus = (
master_config["policy"]["generation"]["colocated"]["resources"][
"gpus_per_node"
]
* generation_num_nodes
)
training_num_gpus = total_num_gpus - generation_num_gpus

e2e_samples_per_sec_per_gpu = (
number_of_samples_per_step / total_time / total_num_gpus
)

e2e_tokens_per_sec_per_gpu = (
metrics["total_num_tokens"] / total_time / total_num_gpus
)
policy_training_tokens_per_sec_per_gpu = (
metrics["total_num_tokens"] / policy_training_time / training_num_gpus
)
policy_and_reference_logprobs_tokens_per_sec_per_gpu = (
metrics["total_num_tokens"]
/ policy_and_reference_logprobs_time
/ training_num_gpus
)
training_worker_group_tokens_per_sec_per_gpu = (
metrics["total_num_tokens"]
/ (policy_training_time + policy_and_reference_logprobs_time)
/ training_num_gpus
)
generation_tokens_per_sec_per_gpu = (
metrics["total_num_tokens"] / generation_time / generation_num_gpus
)

print(" • Throughputs (per GPU):")
print(f" - E2E (Samples/sec/gpu): {e2e_samples_per_sec_per_gpu:.2f}")
print(f" - E2E (Tokens/sec/gpu): {e2e_tokens_per_sec_per_gpu:.2f}")
print(
f" - Policy Training (Tokens/sec/gpu): {policy_training_tokens_per_sec_per_gpu:.2f}"
)
print(
f" - Policy and Reference Logprobs (Tokens/sec/gpu): {policy_and_reference_logprobs_tokens_per_sec_per_gpu:.2f}"
)
print(
f" - Training Worker Group (Tokens/sec/gpu): {training_worker_group_tokens_per_sec_per_gpu:.2f}"
)
print(
f" - Generation Worker Group (Tokens/sec/gpu): {generation_tokens_per_sec_per_gpu:.2f}"
)

print(" • Throughputs (per Group):")
print(
f" - E2E (Samples/sec): {(e2e_samples_per_sec_per_gpu * total_num_gpus):.2f}"
)
print(
f" - E2E (Tokens/sec): {(e2e_tokens_per_sec_per_gpu * total_num_gpus):.2f}"
)
print(
f" - Training Worker Group (Tokens/sec): {(training_worker_group_tokens_per_sec_per_gpu * training_num_gpus):.2f}"
)
print(
f" - Generation Worker Group (Tokens/sec): {(generation_tokens_per_sec_per_gpu * generation_num_gpus):.2f}"
)

# =====================================================
# FLOPS
# =====================================================

if "total_flops" in train_results:
total_tflops = (
train_results["total_flops"] / timing_metrics["policy_training"] / 1e12
)
num_ranks = train_results["num_ranks"]
print(
f" • Training FLOPS: {total_tflops:.2f} TFLOPS ({total_tflops / num_ranks:.2f} TFLOPS per rank)",
flush=True,
)
performance_metrics["train_flops_per_gpu"] = total_tflops / num_ranks
if "theoretical_tflops" in train_results:
theoretical_tflops = train_results["theoretical_tflops"]
print(
f" • Training Model Floating Point Utilization: {100 * total_tflops / theoretical_tflops:.2f}%",
flush=True,
)
performance_metrics["train_fp_utilization"] = (
total_tflops / theoretical_tflops
)

# =====================================================
# Logging
# =====================================================

performance_metrics.update(
{
"samples_per_sec": e2e_samples_per_sec_per_gpu * total_num_gpus,
"tokens_per_sec": e2e_tokens_per_sec_per_gpu * total_num_gpus,
"samples_per_sec_per_gpu": e2e_samples_per_sec_per_gpu,
"tokens_per_sec_per_gpu": e2e_tokens_per_sec_per_gpu,
"policy_training_tokens_per_sec_per_gpu": policy_training_tokens_per_sec_per_gpu,
"policy_and_reference_logprobs_tokens_per_sec_per_gpu": policy_and_reference_logprobs_tokens_per_sec_per_gpu,
"training_worker_group_tokens_per_sec_per_gpu": training_worker_group_tokens_per_sec_per_gpu,
"generation_tokens_per_sec_per_gpu": generation_tokens_per_sec_per_gpu,
"training_worker_group_tokens_per_sec": training_worker_group_tokens_per_sec_per_gpu
* training_num_gpus,
"generation_tokens_per_sec": generation_tokens_per_sec_per_gpu
* generation_num_gpus,
}
)

return performance_metrics
Loading
Loading