diff --git a/hud/rl/config.py b/hud/rl/config.py index 65973472..40480aae 100644 --- a/hud/rl/config.py +++ b/hud/rl/config.py @@ -190,6 +190,9 @@ class TrainingConfig(BaseConfig): optimizer: OptimizerConfig = Field(default_factory=OptimizerConfig, description="Optimizer configuration") max_grad_norm: float = Field(default=1.0, gt=0.0, description="Maximum gradient norm") + # Benchmarking + benchmark: bool = Field(default=False, description="Whether to run in benchmark mode to collect FLOPS and memory usage metrics") + class RewardConfig(BaseConfig): scale_rewards: Literal["group", "batch", "none"] = Field(default="group", description="Reward scaling strategy") leave_one_out: bool = Field(default=False, description="RLOO scaling factor G/(G-1), only applies when scale_rewards='none'") diff --git a/hud/rl/perf.py b/hud/rl/perf.py new file mode 100644 index 00000000..efcf702d --- /dev/null +++ b/hud/rl/perf.py @@ -0,0 +1,177 @@ +import time + +from hud.rl.logger import console +import torch +from torch import nn +from transformers import PretrainedConfig +from hud.rl.utils import get_world_size + + +class PerfCounter: + """ + A class to count throughput (tokens/s) with a rolling window to obtain + precise throughput and MFU estimates. 
+ + Inspired from https://github.com/pytorch/torchtitan/blob/4b3f2e41a084bf79a8540068ed525539d1244edd/torchtitan/utils.py#L119 + """ + + def __init__(self, model: nn.Module, seq_len: int, window_size: int): + self.window_size = window_size + self.tokens = [] + self.times = [] + self.model = model + + + if torch.cuda.is_available(): + self.gpu_peak_flops = self._get_peak_flops(torch.cuda.get_device_name(torch.device("cuda"))) + else: + self.gpu_peak_flops = 0 + # If not tie_word_embeddings, we exclude the embedding parameters from the total number of parameters + # If tie_word_embeddings, the embedding parameters are already excluded (shared with the LM head) + self.num_params = self._get_num_params(model, exclude_embedding=not model.config.tie_word_embeddings) + self.num_flop_per_token = self._get_num_flop_per_token(self.num_params, model.config, seq_len=seq_len) + + def count_tokens(self, tokens: int): + self.tokens.append(tokens) + self.times.append(time.perf_counter()) + if len(self.tokens) > self.window_size: + self.tokens.pop(0) + self.times.pop(0) + + def get_tokens_per_second(self) -> float | None: + if len(self.tokens) < 2: + return None + return sum(self.tokens[1:]) / (self.times[-1] - self.times[0]) + + def get_mfu(self) -> float | None: + tokens_per_second = self.get_tokens_per_second() + if tokens_per_second is None: + return None + return 100 * self.num_flop_per_token * tokens_per_second / self.gpu_peak_flops / get_world_size() + + def _get_peak_flops(self, device_name: str) -> float: + """ + Peak BF16 FLOPs (without sparsity) + + From: https://github.com/pytorch/torchtitan/blob/05e47c38d99fdb1dd39aeba76f080e529a425c5c/torchtitan/tools/utils.py#L69 + """ + if "A100" in device_name: + # https://www.nvidia.com/en-us/data-center/a100/ + return 312e12 + if "H100" in device_name or "H200" in device_name: + # https://www.nvidia.com/en-us/data-center/h100/ + # 
https://resources.nvidia.com/en-us-data-center-overview-mc/en-us-data-center-overview/hpc-datasheet-sc23-h200 + if "NVL" in device_name: + return 835e12 + elif "PCIe" in device_name: + return 756e12 + else: # For H100 SXM and other variants + return 989e12 + if "B200" in device_name: + # https://nvdam.widen.net/s/wwnsxrhm2w/blackwell-datasheet-3384703 + return 2.25e15 # This is half of the FLOPS reported in torchtitan + else: + console.warning_log(f"Peak FLOPS undefined for `{device_name}`. Falling back to A100 (312 TFLOPS)") + return 312e12 + + @staticmethod + def get_active_mm_params(config: PretrainedConfig) -> float: + """Get number of active parameters per token involved in matmuls""" + vocab_size = config.vocab_size + hidden_size = config.hidden_size + intermediate_size = config.intermediate_size + num_attention_heads = config.num_attention_heads + head_dim = hidden_size // num_attention_heads + num_hidden_layers = config.num_hidden_layers + + ## Attention + if hasattr(config, "q_lora_rank") and hasattr(config, "kv_lora_rank"): + # MLA + q_params = num_hidden_layers * ( + hidden_size * config.q_lora_rank + config.q_lora_rank * num_attention_heads * config.qk_head_dim + ) + kv_params = num_hidden_layers * ( + hidden_size * (config.kv_lora_rank + config.qk_rope_head_dim) + + config.kv_lora_rank * num_attention_heads * (config.qk_nope_head_dim + config.v_head_dim) + ) + o_params = num_hidden_layers * (num_attention_heads * config.v_head_dim * hidden_size) + else: + # GQA + num_key_value_heads = config.num_key_value_heads + q_params = num_hidden_layers * hidden_size * num_attention_heads * head_dim + kv_params = 2 * num_hidden_layers * hidden_size * num_key_value_heads * head_dim + o_params = num_hidden_layers * hidden_size * num_attention_heads * head_dim + + ## MLP + if hasattr(config, "first_k_dense_replace"): + num_dense_layers = config.first_k_dense_replace + num_sparse_layers = config.num_hidden_layers - num_dense_layers + elif hasattr(config, 
"num_experts_per_tok"): + num_dense_layers = 0 + num_sparse_layers = config.num_hidden_layers + else: + num_dense_layers = config.num_hidden_layers + num_sparse_layers = 0 + + dense_mlp_params = num_dense_layers * 3 * intermediate_size * hidden_size + sparse_mlp_params = 0 + if hasattr(config, "num_shared_experts"): # Shared experts + sparse_mlp_params += ( + num_sparse_layers * config.num_shared_experts * 3 * config.moe_intermediate_size * hidden_size + ) + if hasattr(config, "num_experts_per_tok"): # Routed experts + sparse_mlp_params += ( + num_sparse_layers * config.num_experts_per_tok * 3 * config.moe_intermediate_size * hidden_size + ) + if hasattr(config, "n_routed_experts"): # DeepSeek Router + sparse_mlp_params += num_sparse_layers * config.n_routed_experts * hidden_size + elif hasattr(config, "num_experts"): # Qwen Router + sparse_mlp_params += num_sparse_layers * config.num_experts * hidden_size + else: + sparse_mlp_params = 0 + + ## LM Head + lm_head_params = vocab_size * hidden_size + ## Total + return q_params + kv_params + o_params + dense_mlp_params + sparse_mlp_params + lm_head_params + + def _get_num_flop_per_token(self, num_params: int, model_config: PretrainedConfig, seq_len: int) -> int: + l, h, q, t = ( + model_config.num_hidden_layers, + model_config.num_attention_heads, + model_config.hidden_size // model_config.num_attention_heads, + seq_len, + ) + # Reasoning behind the factor of 12 for the self-attention part of the formula: + # 1. each self-attention has 2 matmul in the forward and 4 in the backward (6) + # 2. the flash attention does 1 more matmul recomputation in the backward + # but recomputation should not be counted in calculating MFU (+0) + # 3. each matmul performs 1 multiplication and 1 addition (*2) + # 4. 
we follow the convention and do not account for sparsity in causal attention + try: + flop_per_token = 6 * self.get_active_mm_params(model_config) + 12 * l * h * q * t + except Exception as e: + console.warning_log(f"Error calculating flop_per_token using get_active_mm_params: {e}") + flop_per_token = 6 * num_params + 12 * l * h * q * t + + return flop_per_token + + def _get_num_params(self, model: nn.Module, exclude_embedding: bool = False) -> int: + num_params = sum(p.numel() for p in model.parameters()) + if exclude_embedding: + if hasattr(model.lm_head, "weight"): + num_params -= model.lm_head.weight.numel() + elif hasattr(model.lm_head, "base_layer"): # LoRALinear + num_params -= model.lm_head.base_layer.weight.numel() + return num_params + + +_PERF_COUNTER: PerfCounter | None = None + + +def get_perf_counter(model: nn.Module, seq_len: int, window_size: int = 10) -> PerfCounter: + global _PERF_COUNTER + if _PERF_COUNTER is None: + _PERF_COUNTER = PerfCounter(model, seq_len, window_size) + + return _PERF_COUNTER diff --git a/hud/rl/tests/test_train.py b/hud/rl/tests/test_train.py index f5b5eafb..4a23b1d3 100644 --- a/hud/rl/tests/test_train.py +++ b/hud/rl/tests/test_train.py @@ -5,16 +5,17 @@ def main() -> None: training_config = TrainingConfig() - training_config.model = ModelConfig(base_model="Qwen/Qwen2.5-VL-7B-Instruct") + training_config.model = ModelConfig(base_model="Qwen/Qwen2.5-VL-3B-Instruct") training_config.dp_shard = 2 training_config.optimizer.use_8bit_optimizer = False training_config.loss.kl_beta = 0.0 - training_config.output_dir = "/home/ubuntu/hud-python/hud/rl/tests/outputs" + training_config.output_dir = "/home/ubuntu/myworkspace/hud-python/hud/rl/tests/outputs" + training_config.benchmark = True console.info("=" * 80) console.info("Running trainer...") - train(training_config, max_steps=1) + train(training_config, max_steps=5) if __name__ == "__main__": main() diff --git a/hud/rl/tests/utils/prepare_batch.py 
b/hud/rl/tests/utils/prepare_batch.py index b19c733a..ebffb4dd 100644 --- a/hud/rl/tests/utils/prepare_batch.py +++ b/hud/rl/tests/utils/prepare_batch.py @@ -24,7 +24,7 @@ def resolve_pad_token_id(processor): def main(): config = Config() - trace_file = "/home/ubuntu/hud-python/hud/rl/tests/data/traces_de8ea147-3c52-4117-ad24-d1dbaa39a088.json" + trace_file = "/home/ubuntu/myworkspace/hud-python/hud/rl/tests/data/traces_de8ea147-3c52-4117-ad24-d1dbaa39a088.json" print("=" * 80) print("Loading traces from dump...") @@ -41,7 +41,7 @@ def main(): pad_token_id = resolve_pad_token_id(processor) group_size = 8 - num_traces = min(len(traces), 32) + num_traces = min(len(traces), 16) traces = traces[:num_traces] rewards = torch.tensor([float(trace.reward) for trace in traces], dtype=torch.float32) @@ -101,15 +101,17 @@ def main(): tests_root = Path(__file__).resolve().parents[1] outputs_root = tests_root / "outputs" - step_dir = outputs_root / "step_00000" / "rollouts" - step_dir.mkdir(parents=True, exist_ok=True) - for gpu_idx, gpu_batch in enumerate(training_batch): - output_file = step_dir / f"rank_{gpu_idx}.pt" - torch.save(gpu_batch, output_file) - print(f" GPU {gpu_idx}: {output_file}") + for step in range(5): + step_dir = outputs_root / f"step_{step:05d}" / "rollouts" + step_dir.mkdir(parents=True, exist_ok=True) - print("Done!") + for gpu_idx, gpu_batch in enumerate(training_batch): + output_file = step_dir / f"rank_{gpu_idx}.pt" + torch.save(gpu_batch, output_file) + print(f" GPU {gpu_idx}: {output_file}") + + print("Done!") if __name__ == "__main__": diff --git a/hud/rl/train.py b/hud/rl/train.py index ba3a4cc8..f8dbeb7c 100644 --- a/hud/rl/train.py +++ b/hud/rl/train.py @@ -26,6 +26,8 @@ from hud.rl.checkpoint import CheckpointManager from hud.rl.utils import save_step_metrics from hud.rl.types import TrainingSample +from hud.rl.perf import PerfCounter +from rich.table import Table def get_batch(step: int, root: str) -> list[TrainingSample]: @@ -54,6 +56,12 @@ 
def train( console.section_title("Initializing trainer") + if training_config.benchmark: + if is_main_process(): + console.warning_log("Running in benchmark mode, overriding max_steps to 5") + max_steps = min(max_steps, 5) + + parallel_dims = ParallelDims( dp_replicate=training_config.dp_replicate, dp_shard=training_config.dp_shard, @@ -67,6 +75,8 @@ def train( model = build_model(training_config, parallel_dims) + benchmark_data = [] + ref_model: torch.nn.Module | None = None if training_config.loss.kl_beta > 0: console.info_log("Initializing reference model for KL regularization") @@ -82,6 +92,8 @@ def train( collector = MetricsCollector(distributed=(world_size > 1)) + perf_counter: PerfCounter | None = None + for step in range(max_steps): collector.reset() # Save checkpoint from previous step (skip first step since no training yet) @@ -107,7 +119,9 @@ def train( del logits progress.update(f"Computing reference log probabilities... {i + 1}/{len(batch)}") - + if perf_counter is None: + perf_counter = PerfCounter(model, batch[0].inputs["input_ids"].shape[1], 10) + perf_counter.count_tokens(0) with console.progress("Computing old log probabilities...") as progress, torch.no_grad(): for i, minibatch in enumerate(batch): @@ -193,11 +207,39 @@ def train( step_duration = time.time() - training_start_time console.info_log(f"Step {step} training took {step_duration:.2f} seconds") + + # Collect performance data + # sum batch size * sequence length for each minibatch + num_tokens = sum(minibatch.inputs["input_ids"].shape[1] * minibatch.inputs["input_ids"].shape[0] for minibatch in batch) + perf_counter.count_tokens(num_tokens) # Add to rolling window + throughput = perf_counter.get_tokens_per_second() or 0 + mfu = perf_counter.get_mfu() or 0 + peak_memory = torch.cuda.max_memory_allocated() / 1024**3 + + dist_perf_output_list = [{}] * world_size + torch.distributed.all_gather_object(dist_perf_output_list, { + "step_duration": step_duration, + "throughput": throughput, + 
"mfu": mfu, + "peak_memory": peak_memory, + }) + + benchmark_data.append({ + "step": step, + # max step duration across ranks + "step_duration": max([x["step_duration"] for x in dist_perf_output_list]), + # sum throughput across ranks + "throughput": sum([x["throughput"] for x in dist_perf_output_list]), + # sum mfu across ranks (already normalized by world size) + "mfu": sum([x["mfu"] for x in dist_perf_output_list]), + # sum peak memory across ranks + "peak_memory": max([x["peak_memory"] for x in dist_perf_output_list]), + }) + stats = collector.get_stats() if is_main_process(): save_step_metrics(training_config.output_dir, step, stats) - - + torch.cuda.empty_cache() # Save final checkpoint after last training step @@ -205,6 +247,30 @@ def train( console.info(f"Saving final checkpoint for step {max_steps - 1}...") checkpoint_manager.save(model, max_steps - 1) + if training_config.benchmark: + # Create benchmark table + table = Table(title="Training Performance Metrics") + table.add_column("Step", justify="right", style="cyan") + table.add_column("Duration (s)", justify="right", style="yellow") + table.add_column("Throughput (tok/s)", justify="right", style="green") + table.add_column("MFU (%)", justify="right", style="green") + table.add_column("Peak Memory (GB)", justify="right", style="red") + + # Add rows + for data in benchmark_data: + table.add_row( + str(data["step"]), + f"{data['step_duration']:.2f}", + f"{data['throughput']:.0f}", + f"{data['mfu']:.2f}", + f"{data['peak_memory']:.2f}", + ) + + if is_main_process(): + console.section_title("Benchmark Results") + console.print(table) + + def main() -> None: """Main entry point for training script.""" diff --git a/hud/rl/utils.py b/hud/rl/utils.py index 366633db..76d112f0 100644 --- a/hud/rl/utils.py +++ b/hud/rl/utils.py @@ -1,9 +1,11 @@ from pathlib import Path import json +import subprocess from typing import Dict import os import torch import torch.distributed as dist +from hud.rl.logger import console 
def _h100_description_from_lspci(device_name: str) -> str:
    """Return the joined ``lspci`` lines describing NVIDIA H100 devices, or ``device_name`` if none are found."""
    try:
        # Run the lspci command and capture the output
        result = subprocess.run(["lspci"], stdout=subprocess.PIPE, text=True, check=False, timeout=10)
    except (OSError, subprocess.SubprocessError) as e:
        # lspci missing, not executable or hanging: keep the name we were given.
        console.warning(f"Error running lspci: {e}, fallback to use device_name")
        return device_name
    # Filter the output for lines containing both "NVIDIA" and "H100"
    filtered_lines = [line for line in result.stdout.splitlines() if "NVIDIA" in line and "H100" in line]
    # Join all filtered lines into a single string
    return " ".join(filtered_lines) or device_name


# source: https://github.com/pytorch/torchtitan/blob/main/torchtitan/tools/utils.py#L68
# hardcoded BF16 type peak flops for NVIDIA A100, H100, H200, B200 GPU and AMD MI250, MI300X, AMD MI325X and Intel PVC
def get_peak_flops(device_name: str) -> float:
    """
    Return the hardcoded peak BF16 FLOPS for the device called ``device_name``.

    Matching is done on substrings of ``device_name``. For H100 devices the
    ``lspci`` output is used, when available, as the device description so
    that the NVL / PCIe / SXM variant can be told apart. Unknown devices
    log a warning and fall back to the A100 value.

    Args:
        device_name: Device name, e.g. as returned by ``torch.cuda.get_device_name``.

    Returns:
        Peak FLOPS as a float.
    """
    # lspci is only needed to tell H100 variants apart; skipping it for other
    # devices avoids spawning a process and avoids mislabelling a non-H100
    # device on a host that also has an H100 installed.
    if "H100" in device_name:
        device_name = _h100_description_from_lspci(device_name)

    if "A100" in device_name:
        # data from https://www.nvidia.com/en-us/data-center/a100/
        return 312e12
    elif "H100" in device_name:
        # data from https://www.nvidia.com/en-us/data-center/h100/
        # NOTE: Specifications are one-half lower without sparsity.
        if "NVL" in device_name:
            return 835e12
        elif "PCIe" in device_name:
            return 756e12
        else:  # for H100 SXM and other variants
            return 989e12
    elif "H200" in device_name:
        # data from https://www.nvidia.com/en-us/data-center/h200/
        return 989e12
    elif "B200" in device_name:
        # data from https://nvdam.widen.net/s/wwnsxrhm2w/blackwell-datasheet-3384703
        return 2.25e15
    elif "MI300X" in device_name or "MI325X" in device_name:
        # MI300X data from https://www.amd.com/en/products/accelerators/instinct/mi300/mi300x.html
        # MI325X data from https://www.amd.com/en/products/accelerators/instinct/mi300/mi325x.html
        return 1300e12
    elif "MI250X" in device_name:
        # data from https://www.amd.com/en/products/accelerators/instinct/mi200/mi250x.html (per GCD)
        return 191.5e12
    elif "Data Center GPU Max 1550" in device_name:
        # Also known as Ponte Vecchio (PVC).
        # data from https://www.intel.com/content/www/us/en/docs/oneapi/optimization-guide-gpu/2025-0/intel-xe-gpu-architecture.html
        # Dot Product Accumulate Systolic (DPAS):
        # - Freq: 1300MHz
        # - #ops: 512
        # Full EU mode (i.e. 512 max compute units): 340.8 TFLOPS (BF16)
        # Standard EU mode (i.e. 448 max compute units): 298.2 TFLOPS (BF16)
        max_comp_units = torch.xpu.get_device_properties("xpu").max_compute_units
        return 512 * max_comp_units * 1300 * 10**6
    elif "l40s" in device_name.lower():
        # Case-insensitive so that both "l40s" and "L40S" spellings match.
        # data from: "https://resources.nvidia.com/en-us-l40s/l40s-datasheet-28413"
        return 362e12

    else:  # for other GPU types, assume A100
        console.warning(f"Peak flops undefined for: {device_name}, fallback to A100")
        return 312e12