Changes from 4 commits
3 changes: 3 additions & 0 deletions hud/rl/config.py
@@ -190,6 +190,9 @@ class TrainingConfig(BaseConfig):
    optimizer: OptimizerConfig = Field(default_factory=OptimizerConfig, description="Optimizer configuration")
    max_grad_norm: float = Field(default=1.0, gt=0.0, description="Maximum gradient norm")

+   # Benchmarking
+   benchmark: bool = Field(default=False, description="Whether to run in benchmark mode to collect FLOPS and memory usage metrics")
+

class RewardConfig(BaseConfig):
    scale_rewards: Literal["group", "batch", "none"] = Field(default="group", description="Reward scaling strategy")
    leave_one_out: bool = Field(default=False, description="RLOO scaling factor G/(G-1), only applies when scale_rewards='none'")
177 changes: 177 additions & 0 deletions hud/rl/perf.py
@@ -0,0 +1,177 @@
import time

from hud.rl.logger import console
import torch
from torch import nn
from transformers import PretrainedConfig
from hud.rl.utils import get_world_size


class PerfCounter:
    """
    A class to count throughput (tokens/s) with a rolling window to obtain
    precise throughput and MFU estimates.
    Inspired by https://github.com/pytorch/torchtitan/blob/4b3f2e41a084bf79a8540068ed525539d1244edd/torchtitan/utils.py#L119
    """

    def __init__(self, model: nn.Module, seq_len: int, window_size: int):
        self.window_size = window_size
        self.tokens = []
        self.times = []
        self.model = model

        if torch.cuda.is_available():
            self.gpu_peak_flops = self._get_peak_flops(torch.cuda.get_device_name(torch.device("cuda")))
        else:
            self.gpu_peak_flops = 0
        # If not tie_word_embeddings, we exclude the embedding parameters from the total number of parameters
        # If tie_word_embeddings, the embedding parameters are already excluded (shared with the LM head)
        self.num_params = self._get_num_params(model, exclude_embedding=not model.config.tie_word_embeddings)
        self.num_flop_per_token = self._get_num_flop_per_token(self.num_params, model.config, seq_len=seq_len)

    def count_tokens(self, tokens: int):
        self.tokens.append(tokens)
        self.times.append(time.perf_counter())
        if len(self.tokens) > self.window_size:
            self.tokens.pop(0)
            self.times.pop(0)

    def get_tokens_per_second(self) -> float | None:
        if len(self.tokens) < 2:
            return None
        return sum(self.tokens[1:]) / (self.times[-1] - self.times[0])

    def get_mfu(self) -> float | None:
        tokens_per_second = self.get_tokens_per_second()
        if tokens_per_second is None:
            return None
        return 100 * self.num_flop_per_token * tokens_per_second / self.gpu_peak_flops / get_world_size()
Comment on lines +46 to +50

P1: Remove extra world-size division in MFU calculation

Each rank already computes MFU in PerfCounter.get_mfu as 100 * flop_per_token * tokens_per_second / gpu_peak_flops / get_world_size(). When the results are gathered later, they are averaged across ranks, so the world-size factor is applied twice. With two ranks, a device running at 60% MFU will be reported as only 30%, making the new benchmark output misleading. MFU should be computed per device (no division by world size) and then averaged or aggregated once.

Bug: Double Normalization in Distributed MFU Calculation

The MFU calculation in PerfCounter divides by world_size on every rank. These already-normalized MFU values are then averaged across ranks in train.py, leading to an incorrect, double-normalized MFU metric in distributed training.
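A minimal sketch of the fix both comments suggest: make get_mfu per-device and keep the single averaging step that train.py already performs when aggregating the gathered values.

    def get_mfu(self) -> float | None:
        """Per-device MFU; aggregate across ranks exactly once (e.g. in train.py)."""
        tokens_per_second = self.get_tokens_per_second()
        if tokens_per_second is None:
            return None
        # No get_world_size() division here: train.py already averages the gathered values
        return 100 * self.num_flop_per_token * tokens_per_second / self.gpu_peak_flops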


    def _get_peak_flops(self, device_name: str) -> float:
        """
        Peak BF16 FLOPS (without sparsity)
        From: https://github.com/pytorch/torchtitan/blob/05e47c38d99fdb1dd39aeba76f080e529a425c5c/torchtitan/tools/utils.py#L69
        """
        if "A100" in device_name:
            # https://www.nvidia.com/en-us/data-center/a100/
            return 312e12
        if "H100" in device_name or "H200" in device_name:
            # https://www.nvidia.com/en-us/data-center/h100/
            # https://resources.nvidia.com/en-us-data-center-overview-mc/en-us-data-center-overview/hpc-datasheet-sc23-h200
            if "NVL" in device_name:
                return 835e12
            elif "PCIe" in device_name:
                return 756e12
            else:  # For H100 SXM and other variants
                return 989e12
        if "B200" in device_name:
            # https://nvdam.widen.net/s/wwnsxrhm2w/blackwell-datasheet-3384703
            return 2.25e15  # This is half of the FLOPS reported in torchtitan
        else:
            console.warning_log(f"Peak FLOPS undefined for `{device_name}`. Falling back to A100 (312 TFLOPS)")
            return 312e12

    @staticmethod
    def get_active_mm_params(config: PretrainedConfig) -> float:
        """Get number of active parameters per token involved in matmuls"""
        vocab_size = config.vocab_size
        hidden_size = config.hidden_size
        intermediate_size = config.intermediate_size
        num_attention_heads = config.num_attention_heads
        head_dim = hidden_size // num_attention_heads
        num_hidden_layers = config.num_hidden_layers

        ## Attention
        if hasattr(config, "q_lora_rank") and hasattr(config, "kv_lora_rank"):
            # MLA
            q_params = num_hidden_layers * (
                hidden_size * config.q_lora_rank + config.q_lora_rank * num_attention_heads * config.qk_head_dim
            )
            kv_params = num_hidden_layers * (
                hidden_size * (config.kv_lora_rank + config.qk_rope_head_dim)
                + config.kv_lora_rank * num_attention_heads * (config.qk_nope_head_dim + config.v_head_dim)
            )
            o_params = num_hidden_layers * (num_attention_heads * config.v_head_dim * hidden_size)
        else:
            # GQA
            num_key_value_heads = config.num_key_value_heads
            q_params = num_hidden_layers * hidden_size * num_attention_heads * head_dim
            kv_params = 2 * num_hidden_layers * hidden_size * num_key_value_heads * head_dim
            o_params = num_hidden_layers * hidden_size * num_attention_heads * head_dim

        ## MLP
        if hasattr(config, "first_k_dense_replace"):
            num_dense_layers = config.first_k_dense_replace
            num_sparse_layers = config.num_hidden_layers - num_dense_layers
        elif hasattr(config, "num_experts_per_tok"):
            num_dense_layers = 0
            num_sparse_layers = config.num_hidden_layers
        else:
            num_dense_layers = config.num_hidden_layers
            num_sparse_layers = 0

        dense_mlp_params = num_dense_layers * 3 * intermediate_size * hidden_size
        sparse_mlp_params = 0
        if hasattr(config, "num_shared_experts"):  # Shared experts
            sparse_mlp_params += (
                num_sparse_layers * config.num_shared_experts * 3 * config.moe_intermediate_size * hidden_size
            )
        if hasattr(config, "num_experts_per_tok"):  # Routed experts
            sparse_mlp_params += (
                num_sparse_layers * config.num_experts_per_tok * 3 * config.moe_intermediate_size * hidden_size
            )
        if hasattr(config, "n_routed_experts"):  # DeepSeek router
            sparse_mlp_params += num_sparse_layers * config.n_routed_experts * hidden_size
        elif hasattr(config, "num_experts"):  # Qwen router
            sparse_mlp_params += num_sparse_layers * config.num_experts * hidden_size
        else:
            sparse_mlp_params = 0

        ## LM Head
        lm_head_params = vocab_size * hidden_size
        ## Total
        return q_params + kv_params + o_params + dense_mlp_params + sparse_mlp_params + lm_head_params

    def _get_num_flop_per_token(self, num_params: int, model_config: PretrainedConfig, seq_len: int) -> int:
        l, h, q, t = (
            model_config.num_hidden_layers,
            model_config.num_attention_heads,
            model_config.hidden_size // model_config.num_attention_heads,
            seq_len,
        )
        # Reasoning behind the factor of 12 for the self-attention part of the formula:
        # 1. each self-attention has 2 matmuls in the forward and 4 in the backward (6)
        # 2. flash attention does 1 more matmul recomputation in the backward,
        #    but recomputation should not be counted in calculating MFU (+0)
        # 3. each matmul performs 1 multiplication and 1 addition (*2)
        # 4. we follow the convention and do not account for sparsity in causal attention
        try:
            flop_per_token = 6 * self.get_active_mm_params(model_config) + 12 * l * h * q * t
        except Exception as e:
            console.warning_log(f"Error calculating flop_per_token using get_active_mm_params: {e}")
            flop_per_token = 6 * num_params + 12 * l * h * q * t

        return flop_per_token
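    # Worked example for the formula above — illustrative numbers for a
    # hypothetical 7B dense model, not a config from this repo:
    #   N = 7e9 params, l = 32 layers, h = 32 heads, q = 128 head dim, t = 4096 seq len
    #   flop_per_token = 6 * 7e9 + 12 * 32 * 32 * 128 * 4096
    #                  ≈ 4.2e10 + 6.4e9 ≈ 4.8e10 FLOPs per token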

    def _get_num_params(self, model: nn.Module, exclude_embedding: bool = False) -> int:
        num_params = sum(p.numel() for p in model.parameters())
        if exclude_embedding:
            if hasattr(model.lm_head, "weight"):
                num_params -= model.lm_head.weight.numel()
            elif hasattr(model.lm_head, "base_layer"):  # LoRALinear
                num_params -= model.lm_head.base_layer.weight.numel()
        return num_params


_PERF_COUNTER: PerfCounter | None = None


def get_perf_counter(model: nn.Module, seq_len: int, window_size: int = 10) -> PerfCounter:
    global _PERF_COUNTER
    if _PERF_COUNTER is None:
        _PERF_COUNTER = PerfCounter(model, seq_len, window_size)

    return _PERF_COUNTER
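For orientation, a hypothetical usage sketch of this counter in a training loop; `training_step` and `batches` are placeholders, not APIs from this PR:

    # Illustrative wiring only
    counter = get_perf_counter(model, seq_len=4096, window_size=10)
    for step, batch in enumerate(batches):
        training_step(model, batch)                       # assumed per-step helper
        counter.count_tokens(batch["input_ids"].numel())  # tokens processed this step
        tps = counter.get_tokens_per_second()
        if tps is not None:  # needs at least two samples in the window
            print(f"step {step}: {tps:,.0f} tok/s, MFU {counter.get_mfu():.1f}%")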
7 changes: 4 additions & 3 deletions hud/rl/tests/test_train.py
@@ -5,16 +5,17 @@

def main() -> None:
    training_config = TrainingConfig()
-   training_config.model = ModelConfig(base_model="Qwen/Qwen2.5-VL-7B-Instruct")
+   training_config.model = ModelConfig(base_model="Qwen/Qwen2.5-VL-3B-Instruct")
    training_config.dp_shard = 2
    training_config.optimizer.use_8bit_optimizer = False
    training_config.loss.kl_beta = 0.0
-   training_config.output_dir = "/home/ubuntu/hud-python/hud/rl/tests/outputs"
+   training_config.output_dir = "/home/ubuntu/myworkspace/hud-python/hud/rl/tests/outputs"
+   training_config.benchmark = True

    console.info("=" * 80)
    console.info("Running trainer...")

-   train(training_config, max_steps=1)
+   train(training_config, max_steps=5)


if __name__ == "__main__":
    main()
20 changes: 11 additions & 9 deletions hud/rl/tests/utils/prepare_batch.py
@@ -24,7 +24,7 @@ def resolve_pad_token_id(processor):

def main():
    config = Config()
-   trace_file = "/home/ubuntu/hud-python/hud/rl/tests/data/traces_de8ea147-3c52-4117-ad24-d1dbaa39a088.json"
+   trace_file = "/home/ubuntu/myworkspace/hud-python/hud/rl/tests/data/traces_de8ea147-3c52-4117-ad24-d1dbaa39a088.json"

    print("=" * 80)
    print("Loading traces from dump...")

@@ -41,7 +41,7 @@ def main():
    pad_token_id = resolve_pad_token_id(processor)

    group_size = 8
-   num_traces = min(len(traces), 32)
+   num_traces = min(len(traces), 16)
    traces = traces[:num_traces]

    rewards = torch.tensor([float(trace.reward) for trace in traces], dtype=torch.float32)
@@ -101,15 +101,17 @@

    tests_root = Path(__file__).resolve().parents[1]
    outputs_root = tests_root / "outputs"
-   step_dir = outputs_root / "step_00000" / "rollouts"
-   step_dir.mkdir(parents=True, exist_ok=True)
-
-   for gpu_idx, gpu_batch in enumerate(training_batch):
-       output_file = step_dir / f"rank_{gpu_idx}.pt"
-       torch.save(gpu_batch, output_file)
-       print(f"  GPU {gpu_idx}: {output_file}")
-
-   print("Done!")
+   for step in range(5):
+       step_dir = outputs_root / f"step_{step:05d}" / "rollouts"
+       step_dir.mkdir(parents=True, exist_ok=True)
+
+       for gpu_idx, gpu_batch in enumerate(training_batch):
+           output_file = step_dir / f"rank_{gpu_idx}.pt"
+           torch.save(gpu_batch, output_file)
+           print(f"  GPU {gpu_idx}: {output_file}")
+
+   print("Done!")


if __name__ == "__main__":
    main()
72 changes: 69 additions & 3 deletions hud/rl/train.py
@@ -26,6 +26,8 @@
from hud.rl.checkpoint import CheckpointManager
from hud.rl.utils import save_step_metrics
from hud.rl.types import TrainingSample
+ from hud.rl.perf import PerfCounter
+ from rich.table import Table


def get_batch(step: int, root: str) -> list[TrainingSample]:
@@ -54,6 +56,12 @@ def train(

    console.section_title("Initializing trainer")

+   if training_config.benchmark:
+       if is_main_process():
+           console.warning_log("Running in benchmark mode, overriding max_steps to 5")
+       max_steps = min(max_steps, 5)

    parallel_dims = ParallelDims(
        dp_replicate=training_config.dp_replicate,
        dp_shard=training_config.dp_shard,
@@ -67,6 +75,8 @@

    model = build_model(training_config, parallel_dims)

+   benchmark_data = []

    ref_model: torch.nn.Module | None = None
    if training_config.loss.kl_beta > 0:
        console.info_log("Initializing reference model for KL regularization")
@@ -82,6 +92,8 @@

    collector = MetricsCollector(distributed=(world_size > 1))

+   perf_counter: PerfCounter | None = None

    for step in range(max_steps):
        collector.reset()
        # Save checkpoint from previous step (skip first step since no training yet)
@@ -107,7 +119,9 @@
                del logits
                progress.update(f"Computing reference log probabilities... {i + 1}/{len(batch)}")

+       if perf_counter is None:
+           perf_counter = PerfCounter(model, batch[0].inputs["input_ids"].shape[1], 10)
+           perf_counter.count_tokens(0)

        with console.progress("Computing old log probabilities...") as progress, torch.no_grad():
            for i, minibatch in enumerate(batch):
@@ -193,18 +207,70 @@
        step_duration = time.time() - training_start_time
        console.info_log(f"Step {step} training took {step_duration:.2f} seconds")

+       # Collect performance data
+       num_tokens = sum(minibatch.inputs["input_ids"].shape[1] for minibatch in batch)
Bug: Token Counting Error Affects Performance Metrics

The num_tokens calculation at line 212 incorrectly sums only the sequence lengths across minibatches, overlooking the batch dimension. This undercounts the total tokens processed, resulting in inaccurate throughput and MFU metrics reported by the performance counter.
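A minimal sketch of the implied fix, assuming input_ids is shaped [batch, seq_len] so that numel() covers both dimensions:

        # Count all tokens in each minibatch (batch dim × seq dim), not just seq_len
        num_tokens = sum(minibatch.inputs["input_ids"].numel() for minibatch in batch)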

+       console.warning_log(f"num_tokens: {num_tokens}")
Bug: Warning Logs Misused for Routine Info

The console.warning_log for num_tokens logs routine information at the warning level. This can obscure actual warnings and appears to be leftover debug output.
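A one-line suggested change, assuming console.info_log is the appropriate level (it is used elsewhere in this file); dropping the line entirely is the other option:

        console.info_log(f"num_tokens: {num_tokens}")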

+       perf_counter.count_tokens(num_tokens)  # Add to rolling window
+       throughput = perf_counter.get_tokens_per_second() or 0
+       mfu = perf_counter.get_mfu() or 0
+       peak_memory = torch.cuda.max_memory_allocated() / 1024**3
+
+       dist_perf_output_list = [{}] * world_size
+       torch.distributed.all_gather_object(dist_perf_output_list, {
+           "step_duration": step_duration,
+           "throughput": throughput,
+           "mfu": mfu,
+           "peak_memory": peak_memory,
+       })
+
+       benchmark_data.append({
+           "step": step,
+           # max step duration across ranks
+           "step_duration": max([x["step_duration"] for x in dist_perf_output_list]),
+           # sum throughput across ranks
+           "throughput": sum([x["throughput"] for x in dist_perf_output_list]),
+           # average mfu across ranks
+           "mfu": sum([x["mfu"] for x in dist_perf_output_list]) / world_size,
+           # max peak memory across ranks
+           "peak_memory": max([x["peak_memory"] for x in dist_perf_output_list]),
+       })

        stats = collector.get_stats()
        if is_main_process():
            save_step_metrics(training_config.output_dir, step, stats)

        torch.cuda.empty_cache()

    # Save final checkpoint after last training step
    if max_steps > 0:
        console.info(f"Saving final checkpoint for step {max_steps - 1}...")
        checkpoint_manager.save(model, max_steps - 1)

+   if training_config.benchmark:
+       # Create benchmark table
+       table = Table(title="Training Performance Metrics")
+       table.add_column("Step", justify="right", style="cyan")
+       table.add_column("Duration (s)", justify="right", style="yellow")
+       table.add_column("Throughput (tok/s)", justify="right", style="green")
+       table.add_column("MFU (%)", justify="right", style="green")
+       table.add_column("Peak Memory (GB)", justify="right", style="red")
+
+       # Add rows
+       for data in benchmark_data:
+           table.add_row(
+               str(data["step"]),
+               f"{data['step_duration']:.2f}",
+               f"{data['throughput']:.0f}",
+               f"{data['mfu']:.2f}",
+               f"{data['peak_memory']:.2f}",
+           )
+
+       if is_main_process():
+           console.section_title("Benchmark Results")
+           console.print(table)



def main() -> None:
    """Main entry point for training script."""