diff --git a/nemo_rl/algorithms/dpo.py b/nemo_rl/algorithms/dpo.py index 30ba78f6f2..bd43dce528 100644 --- a/nemo_rl/algorithms/dpo.py +++ b/nemo_rl/algorithms/dpo.py @@ -420,15 +420,16 @@ def dpo_train( with timer.time("total_step_time"): print("▶ Taking a training step...") - train_results = policy.train( - batch, - loss_fn, - eval_mode=False, - ## NOTE: we double the batch size here because each preference example corresponds to a pair of - ## examples, chosen and rejected, and the pair needs to be processed as part of the same microbatch. - gbs=master_config["policy"]["train_global_batch_size"] * 2, - mbs=master_config["policy"]["train_micro_batch_size"] * 2, - ) + with timer.time("policy_training"): + train_results = policy.train( + batch, + loss_fn, + eval_mode=False, + ## NOTE: we double the batch size here because each preference example corresponds to a pair of + ## examples, chosen and rejected, and the pair needs to be processed as part of the same microbatch. + gbs=master_config["policy"]["train_global_batch_size"] * 2, + mbs=master_config["policy"]["train_micro_batch_size"] * 2, + ) is_last_step = total_steps + 1 >= master_config["dpo"][ "max_num_steps" @@ -526,6 +527,22 @@ def dpo_train( print("\n📊 Training Results:") print(f" • Loss: {float(metrics['loss']):.4f}") + if "total_flops" in train_results: + total_tflops = ( + train_results["total_flops"] + / timing_metrics["policy_training"] + / 1e12 + ) + num_ranks = train_results["num_ranks"] + print( + f" • Training FLOPS: {total_tflops:.2f} TFLOPS ({total_tflops / num_ranks:.2f} TFLOPS per rank)" + ) + if "theoretical_tflops" in train_results: + theoretical_tflops = train_results["theoretical_tflops"] + print( + f" • Training Model Floating Point Utilization: {100 * total_tflops / theoretical_tflops:.2f}%" + ) + metrics["train_fp_utilization"] = total_tflops / theoretical_tflops print("\n⏱️ Timing:") # Display total time first, separately total_time = timing_metrics.get("total_step_time", 0) diff --git a/nemo_rl/algorithms/grpo.py b/nemo_rl/algorithms/grpo.py index fceb2173c6..9dc5670c63 100644 --- a/nemo_rl/algorithms/grpo.py +++ b/nemo_rl/algorithms/grpo.py @@ -806,6 +806,20 @@ def grpo_train( print( f" • Mean Generation Length: {rollout_metrics['mean_gen_tokens_per_sample']:.4f}" ) + if "total_flops" in train_results: + total_tflops = ( + train_results["total_flops"] / timing_metrics["policy_training"] / 1e12 + ) + num_ranks = train_results["num_ranks"] + print( + f" • Training FLOPS: {total_tflops:.2f} TFLOPS ({total_tflops / num_ranks:.2f} TFLOPS per rank)" + ) + if "theoretical_tflops" in train_results: + theoretical_tflops = train_results["theoretical_tflops"] + print( + f" • Training Model Floating Point Utilization: {100 * total_tflops / theoretical_tflops:.2f}%" + ) + metrics["train_fp_utilization"] = total_tflops / theoretical_tflops print("\n⏱️ Timing:") # Display total time first, separately diff --git a/nemo_rl/algorithms/sft.py b/nemo_rl/algorithms/sft.py index 804909c2c4..7787220af1 100644 --- a/nemo_rl/algorithms/sft.py +++ b/nemo_rl/algorithms/sft.py @@ -405,7 +405,8 @@ def sft_train( ) print("▶ Taking a training step...") - train_results = policy.train(train_data, loss_fn) + with timer.time("policy_training"): + train_results = policy.train(train_data, loss_fn) is_last_step = total_steps + 1 >= master_config["sft"][ "max_num_steps" @@ -502,6 +503,22 @@ def sft_train( print("\n📊 Training Results:") print(f" • Loss: {float(metrics['loss']):.4f}") + if "total_flops" in train_results: + total_tflops = ( + 
train_results["total_flops"] + / timing_metrics["policy_training"] + / 1e12 + ) + num_ranks = train_results["num_ranks"] + print( + f" • Training FLOPS: {total_tflops:.2f} TFLOPS ({total_tflops / num_ranks:.2f} TFLOPS per rank)" + ) + if "theoretical_tflops" in train_results: + theoretical_tflops = train_results["theoretical_tflops"] + print( + f" • Training Model Floating Point Utilization: {100 * total_tflops / theoretical_tflops:.2f}%" + ) + metrics["train_fp_utilization"] = total_tflops / theoretical_tflops print("\n⏱️ Timing:") # Display total time first, separately total_time = timing_metrics.get("total_step_time", 0) diff --git a/nemo_rl/models/policy/dtensor_policy_worker.py b/nemo_rl/models/policy/dtensor_policy_worker.py index 8c0198784b..18ae23d95b 100644 --- a/nemo_rl/models/policy/dtensor_policy_worker.py +++ b/nemo_rl/models/policy/dtensor_policy_worker.py @@ -861,6 +861,8 @@ def train( "global_loss": global_loss.cpu(), "grad_norm": grad_norm, "rank": torch.distributed.get_rank(), + "gpu_name": torch.cuda.get_device_name(), + "model_dtype": self.dtype, "all_mb_metrics": dict(mb_metrics), } diff --git a/nemo_rl/models/policy/lm_policy.py b/nemo_rl/models/policy/lm_policy.py index ebc608e35d..79634570fc 100644 --- a/nemo_rl/models/policy/lm_policy.py +++ b/nemo_rl/models/policy/lm_policy.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import os +import warnings from collections import defaultdict from typing import Any, Optional, Union @@ -41,6 +42,11 @@ LogprobOutputSpec, ReferenceLogprobOutputSpec, ) +from nemo_rl.utils.flops_tracker import ( + FLOPTracker, + get_default_hf_config, + get_theoretical_tflops, +) PathLike = Union[str, "os.PathLike[Any]"] @@ -147,6 +153,15 @@ def __init__( else: self.use_dynamic_batches = False + # initialize FLOPs tracker + try: + self.flops_tracker = FLOPTracker.from_config( + config["model_name"], get_default_hf_config(config["model_name"]) + ) + except ValueError as e: + self.flops_tracker = None + print(f"FLOPS tracker not supported for model {config['model_name']}: {e}") + if config["sequence_packing"]["enabled"]: self.use_sequence_packing = True self.sequence_packing_args: SequencePackingArgs = { @@ -346,6 +361,12 @@ def train( batch_size=batch_size, ) + if self.flops_tracker is not None: + self.flops_tracker.reset() + for shard in sharded_data: + input_lengths = shard["input_lengths"] + self.flops_tracker.track_batch(input_lengths.tolist()) + # Train each shard in parallel futures = self.worker_group.run_all_workers_sharded_data( "train", @@ -376,6 +397,18 @@ def train( "grad_norm": results[0]["grad_norm"], } + if self.flops_tracker is not None: + aggregated_results["total_flops"] = self.flops_tracker.total_flops + aggregated_results["num_ranks"] = len(results) + + try: + aggregated_results["theoretical_tflops"] = sum( + get_theoretical_tflops(r["gpu_name"], r["model_dtype"]) + for r in results + ) + except Exception as e: + warnings.warn(f"Error getting theoretical flops: {e}") + # Aggregate metrics across all workers all_mb_metrics = defaultdict(list) for r in results: diff --git a/nemo_rl/models/policy/megatron_policy_worker.py b/nemo_rl/models/policy/megatron_policy_worker.py index aa1f349b2b..35c18eb701 100644 --- a/nemo_rl/models/policy/megatron_policy_worker.py +++ b/nemo_rl/models/policy/megatron_policy_worker.py @@ -979,6 +979,8 @@ def train( metrics = { "global_loss": global_loss.cpu(), "rank": torch.distributed.get_rank(), + "gpu_name": 
torch.cuda.get_device_name(), + "model_dtype": self.dtype, "all_mb_metrics": dict(mb_metrics), "grad_norm": torch.tensor( mb_metrics["grad_norm"][-1] diff --git a/nemo_rl/utils/flops_formulas.py b/nemo_rl/utils/flops_formulas.py new file mode 100644 index 0000000000..4bbddc39e8 --- /dev/null +++ b/nemo_rl/utils/flops_formulas.py @@ -0,0 +1,553 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import List, Optional, Union + + +# lifted from NeMo/nemo/utils/flops_formulas.py +@dataclass +class FLOPSConfig: + """Contains the model hparams needed for FLOPS computations.""" + + gbs: int + enc_seq_len: Optional[int] = None + hs: Optional[int] = None + layers: Optional[int] = None + ffn_hs: Optional[int] = None + attention_heads: Optional[int] = None + moe_router_topk: Optional[int] = None + query_groups: Optional[int] = None + img_seq_len: Optional[int] = None + img_h: Optional[int] = None + img_w: Optional[int] = None + in_channels: Optional[int] = None + patch_dim: Optional[int] = None + class_token_len: Optional[int] = None + projector_type: Optional[str] = None + inp_s: Optional[int] = None + model_pattern: Optional[str] = None + vocab_size: Optional[int] = None + model_channels: Optional[int] = None + vec_in_dim: Optional[int] = None + q_lora_rank: Optional[int] = None + kv_lora_rank: Optional[int] = None + qk_head_dim: Optional[int] = None + qk_pos_emb_head_dim: Optional[int] = None + v_head_dim: Optional[int] = None + moe_layer_freq: Optional[Union[int, List[int]]] = None + moe_shared_expert_intermediate_size: Optional[int] = None + moe_ffn_hidden_size: Optional[int] = None + mtp_num_layers: Optional[int] = None + causal_self_attn: Optional[bool] = None + is_hybrid_model: bool = False + hybrid_override_pattern: Optional[str] = None + mamba_state_dim: Optional[int] = None + mamba_head_dim: Optional[int] = None + mamba_num_groups: Optional[int] = None + mamba_num_heads: Optional[int] = None + + +def gpt3(config: FLOPSConfig): + """Model FLOPs for GPT3 family.""" + return ( + 24 * config.gbs * config.enc_seq_len * config.hs * config.hs + + 4 * config.gbs * config.enc_seq_len * config.enc_seq_len * config.hs + ) * (3 * config.layers) + ( + 6 * config.gbs * config.enc_seq_len * config.hs * config.vocab_size + ) + + +def llama2(config: FLOPSConfig): + """Model FLOPs for llama2 family.""" + return ( + config.gbs + * config.enc_seq_len + * config.layers + * config.hs + * config.hs + * ( + 12 + + (12 * config.query_groups / config.attention_heads) + + (18 * config.ffn_hs / config.hs) + + (12 * config.enc_seq_len / config.hs) + + (6 * config.vocab_size / (config.layers * config.hs)) + ) + ) + + +def llama3(config: FLOPSConfig): + """Model FLOPs for llama3 family.""" + return ( + config.gbs + * config.enc_seq_len + * config.layers + * config.hs + * config.hs + * ( + 12 + + (12 * config.query_groups / config.attention_heads) + + (18 * config.ffn_hs / config.hs) + + (12 * config.enc_seq_len / 
config.hs) + + (6 * config.vocab_size / (config.layers * config.hs)) + ) + ) + + +def nemotron(config: FLOPSConfig): + """Model FLOPs for nemotron family.""" + return ( + config.gbs + * config.enc_seq_len + * config.layers + * config.hs + * config.hs + * ( + 12 + + (12 * config.query_groups / config.attention_heads) + + (12 * config.ffn_hs / config.hs) + + (12 * config.enc_seq_len / config.hs) + + (6 * config.vocab_size / (config.layers * config.hs)) + ) + ) + + +def mixtral(config: FLOPSConfig): + """Model FLOPs for mixtral family.""" + return ( + config.gbs + * config.enc_seq_len + * config.layers + * config.hs + * config.hs + * ( + 12 + + (12 * config.query_groups / config.attention_heads) + + (18 * config.moe_router_topk * config.ffn_hs / config.hs) + + (12 * config.enc_seq_len / config.hs) + + (6 * config.vocab_size / (config.layers * config.hs)) + ) + ) + + +def qwen2(config: FLOPSConfig): + """Model FLOPs for Qwen2 family.""" + causal_self_attn = True + seq_len = config.enc_seq_len + hidden_size = config.hs + gated_linear_multiplier = 2 + + # attention flops for GQA + attention_flops = ( + 3 + * 2 + * config.gbs + * config.layers + * seq_len + * hidden_size + * hidden_size + * ( + (2 + 1) # QKV gemm + + ( + seq_len / hidden_size * 2 * (0.5 if causal_self_attn else 1) + ) # attention + + 1 # attention proj gemm + ) + ) + + # mlp flops + mlp_flops = ( + 3 + * 2 + * config.gbs + * config.layers + * seq_len + * hidden_size + * (1 + gated_linear_multiplier) + * config.ffn_hs + ) + + # vocab flops + vocab_flops = 3 * 2 * config.gbs * seq_len * hidden_size * config.vocab_size + + return attention_flops + mlp_flops + vocab_flops + + +def qwen3(config: FLOPSConfig): + """Model FLOPs for Qwen3 family.""" + causal_self_attn = True + seq_len = config.enc_seq_len + hidden_size = config.hs + gated_linear_multiplier = 2 + + # attention flops for GQA + attention_flops = ( + 3 + * 2 + * config.gbs + * config.layers + * seq_len + * hidden_size + * hidden_size + * ( + (config.query_groups / config.attention_heads * 2 + 1) # QKV gemm + + ( + seq_len / hidden_size * 2 * (0.5 if causal_self_attn else 1) + ) # attention + + 1 # attention proj gemm + ) + ) + + # mlp flops + mlp_flops = ( + 3 + * 2 + * config.gbs + * config.layers + * seq_len + * hidden_size + * (1 + gated_linear_multiplier) + * (config.moe_ffn_hidden_size * config.moe_router_topk) # MoE layers + ) + + # vocab flops + vocab_flops = 3 * 2 * config.gbs * seq_len * hidden_size * config.vocab_size + + return attention_flops + mlp_flops + vocab_flops + + +def bert(config: FLOPSConfig): + """Model FLOPs for BERT family.""" + return ( + 72 + * config.gbs + * config.layers + * config.enc_seq_len + * config.hs + * config.hs + * ( + 1 + + (config.enc_seq_len / (6 * config.hs)) + + (config.vocab_size / (12 * config.hs * config.layers)) + ) + ) + + +def transformer(config: FLOPSConfig): + """Calculate FLOPs for a standard Transformer model. + + Note: This does not cover encoder-decoder models. 
+ """ + # Extract parameters from config + batch_size = config.gbs + hidden_size = config.hs + seq_length = config.enc_seq_len + num_layers = config.layers + num_attention_heads = config.attention_heads + ffn_hidden_size = config.ffn_hs + vocab_size = config.vocab_size + + if vocab_size is None: + raise ValueError("vocab_size is required for transformer FLOPs calculation") + + # Handle optional parameters with reasonable defaults + query_groups = ( + config.query_groups if config.query_groups is not None else num_attention_heads + ) + causal_self_attn = ( + config.causal_self_attn if config.causal_self_attn is not None else False + ) + moe_router_topk = ( + config.moe_router_topk if config.moe_router_topk is not None else 0 + ) + kv_channels = hidden_size // num_attention_heads # Standard dimension per head + + # Calculate query projection size and ratio + query_projection_size = kv_channels * num_attention_heads + query_projection_to_hidden_size_ratio = query_projection_size / hidden_size + + # MoE parameters - simplified for NeMo config + # In this implementation, we assume all layers are dense if num_experts is None + if moe_router_topk == 0: + num_dense_layers = num_layers + num_moe_layers = 0 + num_experts_routed_to = 0 + else: + # Simplified MoE handling - assuming uniform distribution of MoE layers + # This can be expanded based on NeMo's actual MoE implementation + num_moe_layers = num_layers // 2 # Simplified assumption + num_dense_layers = num_layers - num_moe_layers + num_experts_routed_to = moe_router_topk + + # Handle SwiGLU vs standard GELU/ReLU + # Default to standard activation (no SwiGLU) + gated_linear_multiplier = 1 + + # Define the expansion factor as described in the paper + # 3x: Each GEMM needs forward pass, backward wgrad, and backward dgrad + # 2x: GEMMs are stacked twice in standard Transformer architectures + # 2x: A GEMM of m*n with n*k requires 2mnk floating-point operations + expansion_factor = 3 * 2 * 2 + # Attention + if not causal_self_attn: + attention_component = ( + 1 + + (query_groups / num_attention_heads) + # Only half of the attention matrix is non-zero and needs to be multiplied with V + + (seq_length / hidden_size) # If causal self attn -> divide by 2. + ) * query_projection_to_hidden_size_ratio + else: + attention_component = ( + 1 + + (query_groups / num_attention_heads) + # Only half of the attention matrix is non-zero and needs to be multiplied with V + + (seq_length / hidden_size / 2) # If causal self attn -> divide by 2. 
+ ) * query_projection_to_hidden_size_ratio + + # Calculate total FLOPs + total_flops = ( + expansion_factor + * batch_size + * seq_length + * num_layers + * hidden_size + * hidden_size + * ( + attention_component + # MLP component + + ( + ( + # Dense layers + (ffn_hidden_size * num_dense_layers) + + + # MoE layers + ( + ( + # Routed experts + ffn_hidden_size * num_experts_routed_to + # Note: Shared experts are not implemented in this version + ) + * num_moe_layers + ) + ) + * gated_linear_multiplier + / (num_layers * hidden_size) + ) + # Logit component + + (vocab_size / (2 * num_layers * hidden_size)) + ) + ) + + return total_flops + + +def flux(config: FLOPSConfig): + """Model FLOPs for FLUX.""" + hs = config.hs + seq_len = config.model_channels + config.inp_s + base_factor = 6 * config.gbs # common multiplier for most terms + + # Joint layer computations + joint_layer_flops = ( + base_factor + * config.layers[0] + * ( + 10 * hs * hs # hidden size operations + + 2 + * hs + * (config.model_channels + config.inp_s) + * (1 + hs * 7) # channel and context joint attention + + 2 * (config.model_channels + config.inp_s) * hs # final projection + ) + ) + + # Single layer computations + single_layer_flops = ( + base_factor + * config.layers[1] + * seq_len + * hs + * ( + 3 # linear Y + + 1 # Modulation + + 4 * hs # Linear computations + + (3 * hs + 2 * seq_len) # attention operations + + 5 * hs # feed-forward + + 1 # Modulation + ) + ) + + # Embedding and projection layers + other_flops = base_factor * ( + config.inp_s * config.in_channels * hs # image embedding + + config.inp_s * hs * config.model_channels # text embedding + + config.vec_in_dim * hs + + hs * hs # vector embedding + + 2 * (config.model_channels * hs + hs * hs) # guidance + timestep embedding + + (config.inp_s * config.in_channels * hs) / config.gbs # final projection + ) + + return joint_layer_flops + single_layer_flops + other_flops + + +def deepseekv3(config: FLOPSConfig): + """Model FLOPs for DeepSeek V3.""" + # self-attention flops + bmm1_flops = ( + 0.5 + * (config.qk_head_dim + config.qk_pos_emb_head_dim) + * config.attention_heads + * (config.enc_seq_len**2) + ) + bmm2_flops = ( + 0.5 * config.v_head_dim * config.attention_heads * (config.enc_seq_len**2) + ) + per_input_attention_flops = 6 * (bmm1_flops + bmm2_flops) * config.layers + if config.mtp_num_layers is not None: + per_input_attention_flops += ( + 6 * (bmm1_flops + bmm2_flops) * config.mtp_num_layers + ) + + # linear layer flops + per_layer_mla_params = config.hs * config.q_lora_rank + config.q_lora_rank * ( + (config.qk_head_dim + config.qk_pos_emb_head_dim) * config.attention_heads + ) # Q + per_layer_mla_params += config.hs * config.qk_pos_emb_head_dim # K^R + per_layer_mla_params += config.hs * config.kv_lora_rank + config.kv_lora_rank * ( + (config.qk_head_dim + config.v_head_dim) * config.attention_heads + ) # K^C and V^C + per_layer_mla_params += ( + config.v_head_dim * config.attention_heads * config.hs + ) # Proj + mla_params = per_layer_mla_params * config.layers + if config.mtp_num_layers is not None: + mla_params += per_layer_mla_params * config.mtp_num_layers + + dense_layer_ffn_params = config.hs * config.ffn_hs * 3 # gated linear unit + per_shared_expert_params = ( + config.hs * config.moe_shared_expert_intermediate_size * 3 + ) + per_selected_expert_params = config.hs * config.moe_ffn_hidden_size * 3 + ffn_params = 0 + + if isinstance(config.moe_layer_freq, int): + moe_layer_pattern = [ + 1 if (i % config.moe_layer_freq == 0) else 0 for i in 
range(config.layers) + ] + else: + moe_layer_pattern = config.moe_layer_freq + for i in moe_layer_pattern: + if i == 0: + ffn_params += dense_layer_ffn_params + else: + ffn_params += per_shared_expert_params + ( + per_selected_expert_params * config.moe_router_topk + ) + if config.mtp_num_layers is not None: + for i in range(config.mtp_num_layers): + ffn_params += per_shared_expert_params + ( + per_selected_expert_params * config.moe_router_topk + ) + per_input_params = mla_params + ffn_params + per_input_linear_flops = 6 * per_input_params * config.enc_seq_len + + # vocab flops + per_input_vocab_flops = 6 * config.vocab_size * config.hs * config.enc_seq_len + if config.mtp_num_layers is not None: + for i in range(config.mtp_num_layers): + per_input_vocab_flops += ( + 6 * config.vocab_size * config.hs * config.enc_seq_len + ) + per_input_vocab_flops += 6 * config.hs * 2 * config.hs * config.enc_seq_len + + return ( + per_input_attention_flops + per_input_linear_flops + per_input_vocab_flops + ) * config.gbs + + +def _mlp_layer_flops(config: FLOPSConfig): + """Model FLOPs for MLP layer.""" + return ( + 6 + * config.gbs + * config.enc_seq_len + * config.hs + * config.ffn_hs + * (2 if config.gated_linear_unit else 1) + ) + + +def _non_mla_attn_layer_flops(config: FLOPSConfig): + """Model FLOPs for attention layer.""" + return ( + 6 + * config.gbs + * config.enc_seq_len + * config.hs + * ( + config.hs # Q + + config.query_groups / config.attention_heads * config.hs * 2 # KV + + config.enc_seq_len / 2 * 2 + + config.hs + ) + ) + + +def _mamba_layer_flops(config: FLOPSConfig): + """Model FLOPs for Mamba layer. We ignore part of the flops of scan because the chunk size is not known from model config.""" + assert config.mamba_state_dim is not None + assert config.mamba_head_dim is not None + + if config.mamba_num_heads: + nheads = config.mamba_num_heads + else: + nheads = 2 * config.hs // config.mamba_head_dim # default expand is 2 + d_in = nheads * config.mamba_head_dim + return ( + ( + 6 + * config.gbs + * config.enc_seq_len + * config.hs + * (2 * d_in + 2 * config.mamba_num_groups * config.mamba_state_dim + nheads) + ) + + (3 * 2 * config.gbs * config.enc_seq_len * d_in * config.mamba_state_dim) + + (6 * config.gbs * config.enc_seq_len * d_in * config.hs) + ) + + +def _hybrid_model_flops(config: FLOPSConfig): + """Model FLOPs for hybrid model.""" + assert config.is_hybrid_model == True + assert config.hybrid_override_pattern is not None + + num_attn_layers, num_mamba_layers, num_mlp_layers = 0, 0, 0 + for c in config.hybrid_override_pattern: + if c == "M": + num_mamba_layers += 1 + elif c == "-": + num_mlp_layers += 1 + elif c == "*": + num_attn_layers += 1 + return ( + num_attn_layers * _non_mla_attn_layer_flops(config) + + num_mamba_layers * _mamba_layer_flops(config) + + num_mlp_layers * _mlp_layer_flops(config) + + 6 * config.gbs * config.enc_seq_len * config.hs * config.vocab_size + ) + + +def nemotronh(config: FLOPSConfig): + """Model FLOPs for NemotronH.""" + return _hybrid_model_flops(config) diff --git a/nemo_rl/utils/flops_tracker.py b/nemo_rl/utils/flops_tracker.py new file mode 100644 index 0000000000..aab769fdaf --- /dev/null +++ b/nemo_rl/utils/flops_tracker.py @@ -0,0 +1,142 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import asdict +from typing import Callable, Optional + +import torch +from transformers import AutoConfig +from transformers.configuration_utils import PretrainedConfig +from transformers.models.llama.configuration_llama import LlamaConfig +from transformers.models.qwen2.configuration_qwen2 import Qwen2Config +from transformers.models.qwen3.configuration_qwen3 import Qwen3Config + +from nemo_rl.models.policy.utils import sliding_window_overwrite +from nemo_rl.utils.flops_formulas import FLOPSConfig, llama2, llama3, qwen2, qwen3 + + +def get_default_hf_config(model_name: str) -> PretrainedConfig: + """Get the default Hugging Face config for a model. + + Both the DTensor and MCore paths use the same default config, we initialize the model config + here to allow computation of theoretical flops which is agnostic to the backend. + """ + return AutoConfig.from_pretrained( + model_name, + torch_dtype=torch.float32, + trust_remote_code=True, + **sliding_window_overwrite(model_name), + ) + + +def convert_config_to_flops_config( + model_name: str, config: PretrainedConfig +) -> tuple[FLOPSConfig, Callable]: + """Convert a pretrained config to a tuple containing a FLOPSConfig and a flops formula.""" + if isinstance(config, Qwen2Config): + return FLOPSConfig( + gbs=0, + hs=config.hidden_size, + layers=config.num_hidden_layers, + ffn_hs=config.intermediate_size, + vocab_size=config.vocab_size, + ), qwen2 + elif isinstance(config, Qwen3Config): + return FLOPSConfig( + gbs=0, + hs=config.hidden_size, + layers=config.num_hidden_layers, + ffn_hs=config.intermediate_size, + vocab_size=config.vocab_size, + query_groups=config.num_attention_heads / config.num_key_value_heads, + attention_heads=config.num_attention_heads, + # for non-MoE models, we use the intermediate size as the ffn hidden size + moe_ffn_hidden_size=config.intermediate_size, + moe_router_topk=1, + ), qwen3 + elif isinstance(config, LlamaConfig): + return FLOPSConfig( + gbs=0, + hs=config.hidden_size, + layers=config.num_hidden_layers, + ffn_hs=config.intermediate_size, + query_groups=config.num_attention_heads / config.num_key_value_heads, + attention_heads=config.num_attention_heads, + vocab_size=config.vocab_size, + ), llama3 if "llama3" in model_name.lower() else llama2 + else: + raise ValueError(f"Unsupported config type: {type(config)}") + + +THEORETICAL_TFLOPS = { + ("NVIDIA H100 80GB HBM3", torch.bfloat16): 1979 / 2, + ("NVIDIA H100 80GB HBM3", torch.float32): 67.0, +} + + +def get_theoretical_tflops(device_name: str, model_dtype: torch.dtype) -> float: + """Get the theoretical total flops for a device name.""" + if (device_name, model_dtype) in THEORETICAL_TFLOPS: + return THEORETICAL_TFLOPS[(device_name, model_dtype)] + else: + raise ValueError( + f"Unknown device name: {device_name} and dtype name: {model_dtype}" + ) + + +class FLOPTracker: + def __init__( + self, + model_name: str, + base_config: FLOPSConfig | None = None, + flops_formula: Callable[[FLOPSConfig], float] | None = None, + ): + self.model_name = model_name + self.base_config = base_config + self.total_flops = 0 + self.flops_formula: 
Optional[Callable[[FLOPSConfig], float]] = flops_formula + + @classmethod + def from_config(cls, model_name: str, config: PretrainedConfig) -> "FLOPTracker": + flops_config, flops_formula = convert_config_to_flops_config(model_name, config) + return cls( + model_name=model_name, base_config=flops_config, flops_formula=flops_formula + ) + + def track(self, n_samples: int, padded_seq_len: int): + if self.flops_formula is None: + raise ValueError("Flops formula is not set") + + base_config_dict = ( + asdict(self.base_config) if self.base_config is not None else {} + ) + + # Override gbs and enc_seq_len with current values + config_dict = { + **base_config_dict, + "gbs": n_samples, + "enc_seq_len": padded_seq_len, + } + + # Compute and accumulate flops + flops = self.flops_formula(FLOPSConfig(**config_dict)) + self.total_flops += flops + + def track_batch(self, sequence_lengths: list[int]): + """Track the flops for a batch of sequences.""" + for seq_len in sequence_lengths: + self.track(n_samples=1, padded_seq_len=seq_len) + + def reset(self): + self.total_flops = 0 diff --git a/tests/unit/models/policy/test_dtensor_worker.py b/tests/unit/models/policy/test_dtensor_worker.py index ba33408a8d..ab876eb214 100644 --- a/tests/unit/models/policy/test_dtensor_worker.py +++ b/tests/unit/models/policy/test_dtensor_worker.py @@ -377,6 +377,28 @@ def verify_loss_tensor(loss_tensor): # Verify loss changed between iterations (model parameters were updated) assert losses[0] > losses[-1], "Loss should decrease over training iterations" + # Verify the train function returns the performance metrics + + if policy.flops_tracker is not None: + assert "total_flops" in results and isinstance( + results["total_flops"], (int, float) + ), "training backend should report total_flops" + assert results["total_flops"] > 0, "total_flops should be positive" + assert "num_ranks" in results and isinstance(results["num_ranks"], int), ( + "training backend should report num_ranks" + ) + assert results["num_ranks"] > 0, "num_ranks should be positive" + + # we don't always require theoretical_tflops since the data about the GPU + # is not always available. + if "theoretical_tflops" in results: + assert isinstance(results["theoretical_tflops"], (int, float)), ( + "training backend should report theoretical_tflops" + ) + assert results["theoretical_tflops"] > 0, ( + "theoretical_tflops should be positive" + ) + @pytest.fixture def logprob_setup(request, two_gpu_virtual_cluster): diff --git a/tests/unit/models/policy/test_megatron_worker.py b/tests/unit/models/policy/test_megatron_worker.py index 38607ba59f..20c31324e0 100644 --- a/tests/unit/models/policy/test_megatron_worker.py +++ b/tests/unit/models/policy/test_megatron_worker.py @@ -387,6 +387,26 @@ def verify_loss_tensor(loss_tensor): # Verify loss changed between iterations (model parameters were updated) assert losses[0] > losses[-1], "Loss should decrease over training iterations" + if policy.flops_tracker is not None: + assert "total_flops" in results and isinstance( + results["total_flops"], (int, float) + ), "training backend should report total_flops" + assert results["total_flops"] > 0, "total_flops should be positive" + assert "num_ranks" in results and isinstance(results["num_ranks"], int), ( + "training backend should report num_ranks" + ) + assert results["num_ranks"] > 0, "num_ranks should be positive" + + # we don't always require theoretical_tflops since the data about the GPU + # is not always available. 
+ if "theoretical_tflops" in results: + assert "theoretical_tflops" in results and isinstance( + results["theoretical_tflops"], (int, float) + ), "training backend should report theoretical_tflops" + assert results["theoretical_tflops"] > 0, ( + "theoretical_tflops should be positive" + ) + @pytest.fixture def generation_setup(request, tiny_llama_model_path):