diff --git a/nemo_rl/models/megatron/setup.py b/nemo_rl/models/megatron/setup.py
index 6a58d930d5..25bbc66f02 100644
--- a/nemo_rl/models/megatron/setup.py
+++ b/nemo_rl/models/megatron/setup.py
@@ -468,17 +468,17 @@ def _apply_performance_config(model_cfg: Any, config: PolicyConfig) -> None:
     # Fusion settings
     model_cfg.apply_rope_fusion = config["megatron_cfg"]["apply_rope_fusion"]
     model_cfg.bias_activation_fusion = config["megatron_cfg"]["bias_activation_fusion"]
 
-    # Optional explicit attention backend override for environments where
-    # TE auto backend probing is unstable.
+    # Attention backend configuration
     attention_backend = config["megatron_cfg"].get("attention_backend")
     if attention_backend is not None:
-        if isinstance(attention_backend, str):
+        for _nvte_var in ("NVTE_FUSED_ATTN", "NVTE_FLASH_ATTN", "NVTE_UNFUSED_ATTN"):
+            os.environ.pop(_nvte_var, None)
+        try:
             model_cfg.attention_backend = AttnBackend[attention_backend]
-        elif isinstance(attention_backend, int):
-            model_cfg.attention_backend = AttnBackend(attention_backend)
-        else:
+        except KeyError:
             raise ValueError(
-                f"Unsupported {type(attention_backend)=}, expected str or int"
+                f"Invalid attention backend: {attention_backend}. "
+                f"Available backends are: {list(AttnBackend.__members__.keys())}"
             )
     # FP8 configuration
diff --git a/nemo_rl/models/policy/__init__.py b/nemo_rl/models/policy/__init__.py
index 3636e5ac64..73edc7a197 100644
--- a/nemo_rl/models/policy/__init__.py
+++ b/nemo_rl/models/policy/__init__.py
@@ -200,6 +200,9 @@ class MegatronConfig(TypedDict):
     bias_activation_fusion: bool
     # Force overwrite of the initial checkpoint even if it exists (default: False)
     force_overwrite_initial_ckpt: NotRequired[bool]
+    # Attention backend available values:
+    # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/transformer/enums.py#L60
+    attention_backend: NotRequired[str]
     moe_per_layer_logging: bool
     # Set to true to enable DeepEP for expert parallel communication
     # Must set moe_token_dispatcher_type to 'flex'
diff --git a/tests/unit/models/policy/test_megatron_worker.py b/tests/unit/models/policy/test_megatron_worker.py
index eff97b215c..09adb3811f 100644
--- a/tests/unit/models/policy/test_megatron_worker.py
+++ b/tests/unit/models/policy/test_megatron_worker.py
@@ -65,6 +65,7 @@ def create_megatron_test_config(
     converter_type: str = "LlamaForCausalLM",
     logprob_chunk_size: Optional[int] = None,
     defer_fp32_logits: Optional[bool] = None,
+    attention_backend: Optional[str] = None,
 ) -> PolicyConfig:
     """Create a test config for Megatron policy worker."""
     return {
@@ -179,6 +180,7 @@ def create_megatron_test_config(
                 "fp8_recipe": "tensorwise",
                 "fp8_param": True,
             },
+            "attention_backend": attention_backend,
         },
         "make_sequence_length_divisible_by": tp,
         "optimizer": None,  # Remove default FSDP optimizer
@@ -316,6 +318,10 @@ def training_setup(request):
            config["megatron_cfg"]["sequence_parallel"] = config_updates[
                "sequence_parallel"
            ]
+        if "attention_backend" in config_updates:
+            config["megatron_cfg"]["attention_backend"] = config_updates[
+                "attention_backend"
+            ]
 
    tokenizer = get_tokenizer(config["tokenizer"])
    config["generation"] = configure_generation_config(
@@ -376,15 +382,10 @@ def training_setup(request):
         (2, 1, 1, "tiny_qwen2_model_path", {}),
         (2, 2, 1, "tiny_qwen2_model_path", {}),
         (2, 1, 1, "tiny_llama_model_path", {"precision": "bfloat16"}),
-        (
-            2,
-            1,
-            1,
-            "tiny_llama_model_path",
-            {"activation_checkpointing": True},
-        ),
+        (2, 1, 1, "tiny_llama_model_path", {"activation_checkpointing": True}),
         (2, 2, 1, "tiny_llama_model_path", {"sequence_parallel": True}),
         (2, 2, 1, "tiny_llama_model_path", {"precision": "bfloat16", "fp8": "hybrid"}),
+        (2, 1, 1, "tiny_llama_model_path", {"attention_backend": "flash", "precision": "bfloat16"}),
     ],
     indirect=True,
     ids=[
@@ -396,6 +397,7 @@ def training_setup(request):
         "2gpu_dp2_llama_ac",
         "2gpu_tp2_llama_sp",
         "2gpu_tp2_llama_fp8",
+        "2gpu_dp2_llama_attention_backend_flash",
     ],
 )
 def test_megatron_policy_training(training_setup):