diff --git a/examples/configs/distillation_math.yaml b/examples/configs/distillation_math.yaml
index 62937754f1..9e32814a62 100644
--- a/examples/configs/distillation_math.yaml
+++ b/examples/configs/distillation_math.yaml
@@ -107,6 +107,9 @@ policy: &POLICY_BASE
     bias_activation_fusion: True
     defer_fp32_logits: False
     moe_per_layer_logging: False
+    moe_enable_deepep: false
+    moe_token_dispatcher_type: "allgather"
+    moe_shared_expert_overlap: false

     optimizer:
       optimizer: "adam"
diff --git a/examples/configs/distillation_math_megatron.yaml b/examples/configs/distillation_math_megatron.yaml
index 644d240a7b..04802ac0ca 100644
--- a/examples/configs/distillation_math_megatron.yaml
+++ b/examples/configs/distillation_math_megatron.yaml
@@ -59,6 +59,9 @@ policy: &POLICY_BASE
     bias_activation_fusion: True
     moe_per_layer_logging: False
     defer_fp32_logits: False
+    moe_enable_deepep: false
+    moe_token_dispatcher_type: "allgather"
+    moe_shared_expert_overlap: false

     optimizer:
       optimizer: "adam"
diff --git a/examples/configs/dpo.yaml b/examples/configs/dpo.yaml
index a5e4f037af..fcfc22e1bd 100755
--- a/examples/configs/dpo.yaml
+++ b/examples/configs/dpo.yaml
@@ -119,6 +119,9 @@ policy:
     bias_activation_fusion: True
     defer_fp32_logits: False
     moe_per_layer_logging: False
+    moe_enable_deepep: false
+    moe_token_dispatcher_type: "allgather"
+    moe_shared_expert_overlap: false

     optimizer:
       optimizer: "adam"
diff --git a/examples/configs/grpo_math_1B.yaml b/examples/configs/grpo_math_1B.yaml
index ec8c2c5ecc..52de51905c 100644
--- a/examples/configs/grpo_math_1B.yaml
+++ b/examples/configs/grpo_math_1B.yaml
@@ -116,6 +116,9 @@ policy:
     bias_activation_fusion: True
     defer_fp32_logits: False
     moe_per_layer_logging: False
+    moe_enable_deepep: false
+    moe_token_dispatcher_type: "allgather"
+    moe_shared_expert_overlap: false

     optimizer:
       optimizer: "adam"
diff --git a/examples/configs/grpo_math_1B_megatron.yaml b/examples/configs/grpo_math_1B_megatron.yaml
index 1a14b8ce64..2b7d4473b4 100644
--- a/examples/configs/grpo_math_1B_megatron.yaml
+++ b/examples/configs/grpo_math_1B_megatron.yaml
@@ -94,6 +94,9 @@ policy:
     moe_router_load_balancing_type: "none" # "seq_aux_loss" causes logprob error divergence for grpo
     moe_router_bias_update_rate: 0.0 # by default, disable bias updates for grpo
     moe_permute_fusion: false
+    moe_enable_deepep: false
+    moe_token_dispatcher_type: "allgather"
+    moe_shared_expert_overlap: false

     #gives ~20% training perf speedup with sequence packing
     apply_rope_fusion: True
diff --git a/examples/configs/recipes/llm/grpo-dapomath17k-dsv3-megatron.yaml b/examples/configs/recipes/llm/grpo-dapomath17k-dsv3-megatron.yaml
index 6e00ecd37c..bc636d931f 100644
--- a/examples/configs/recipes/llm/grpo-dapomath17k-dsv3-megatron.yaml
+++ b/examples/configs/recipes/llm/grpo-dapomath17k-dsv3-megatron.yaml
@@ -29,6 +29,8 @@ policy:
     sequence_parallel: true
     moe_permute_fusion: true
     apply_rope_fusion: false
+    moe_enable_deepep: true
+    moe_token_dispatcher_type: flex
     optimizer:
       lr: 5.0e-07
       min_lr: 5.0e-08
diff --git a/examples/configs/recipes/llm/grpo-moonlight-16ba3b-4n8g-megatron.yaml b/examples/configs/recipes/llm/grpo-moonlight-16ba3b-4n8g-megatron.yaml
index e1e38fbbfc..83ea6128ef 100644
--- a/examples/configs/recipes/llm/grpo-moonlight-16ba3b-4n8g-megatron.yaml
+++ b/examples/configs/recipes/llm/grpo-moonlight-16ba3b-4n8g-megatron.yaml
@@ -31,6 +31,8 @@ policy:
       lr: 1.0e-06
     scheduler:
       lr_warmup_iters: 50
+    moe_enable_deepep: true
+    moe_token_dispatcher_type: flex
 logger:
   monitor_gpus: false
   wandb:
diff --git a/examples/configs/sft.yaml b/examples/configs/sft.yaml
index 747564f422..486841cdc2 100644
--- a/examples/configs/sft.yaml
+++ b/examples/configs/sft.yaml
@@ -113,6 +113,9 @@ policy:
     bias_activation_fusion: True
     defer_fp32_logits: False
     moe_per_layer_logging: False
+    moe_enable_deepep: false
+    moe_token_dispatcher_type: "allgather"
+    moe_shared_expert_overlap: false

   peft:
     enabled: false
diff --git a/examples/configs/sft_openmathinstruct2_megatron.yaml b/examples/configs/sft_openmathinstruct2_megatron.yaml
index b0f94fff6d..0925a5c29a 100644
--- a/examples/configs/sft_openmathinstruct2_megatron.yaml
+++ b/examples/configs/sft_openmathinstruct2_megatron.yaml
@@ -92,6 +92,9 @@ policy:
     # gives ~25% training perf speedup with sequence packing and apply_rope_fusion
     bias_activation_fusion: True
     moe_per_layer_logging: False
+    moe_enable_deepep: false
+    moe_token_dispatcher_type: "allgather"
+    moe_shared_expert_overlap: false

     env_vars:
       PYTORCH_CUDA_ALLOC_CONF: "expandable_segments:False"
diff --git a/examples/configs/vlm_grpo_3B.yaml b/examples/configs/vlm_grpo_3B.yaml
index 47233d87db..ec2f8531c0 100644
--- a/examples/configs/vlm_grpo_3B.yaml
+++ b/examples/configs/vlm_grpo_3B.yaml
@@ -104,6 +104,9 @@ policy:
     bias_activation_fusion: True
     defer_fp32_logits: False
     moe_per_layer_logging: False
+    moe_enable_deepep: false
+    moe_token_dispatcher_type: "allgather"
+    moe_shared_expert_overlap: false

     optimizer:
       optimizer: "adam"
diff --git a/examples/configs/vlm_grpo_3B_megatron.yaml b/examples/configs/vlm_grpo_3B_megatron.yaml
index 64f8ea158d..7ae6f38c67 100644
--- a/examples/configs/vlm_grpo_3B_megatron.yaml
+++ b/examples/configs/vlm_grpo_3B_megatron.yaml
@@ -146,6 +146,9 @@ policy:
     bias_activation_fusion: True
     defer_fp32_logits: False
     moe_per_layer_logging: False
+    moe_enable_deepep: false
+    moe_token_dispatcher_type: "allgather"
+    moe_shared_expert_overlap: false
     optimizer:
       optimizer: adam
       lr: 2.0e-07
diff --git a/examples/nemo_gym/grpo_workplace_assistant_nemotron_nano_v2_9b.yaml b/examples/nemo_gym/grpo_workplace_assistant_nemotron_nano_v2_9b.yaml
index c28d958cdc..ae3740850d 100644
--- a/examples/nemo_gym/grpo_workplace_assistant_nemotron_nano_v2_9b.yaml
+++ b/examples/nemo_gym/grpo_workplace_assistant_nemotron_nano_v2_9b.yaml
@@ -100,6 +100,9 @@ policy:
     apply_rope_fusion: True
     defer_fp32_logits: false
     moe_permute_fusion: false
+    moe_enable_deepep: false
+    moe_token_dispatcher_type: "allgather"
+    moe_shared_expert_overlap: false

     optimizer:
       optimizer: "adam"
diff --git a/nemo_rl/models/policy/__init__.py b/nemo_rl/models/policy/__init__.py
index 1a934a26d4..363399cbca 100644
--- a/nemo_rl/models/policy/__init__.py
+++ b/nemo_rl/models/policy/__init__.py
@@ -183,6 +183,16 @@ class MegatronConfig(TypedDict):
     # Force overwrite of the initial checkpoint even if it exists (default: False)
     force_overwrite_initial_ckpt: NotRequired[bool]
     moe_per_layer_logging: bool
+    # Set to true to enable DeepEP for expert parallel communication.
+    # Must set moe_token_dispatcher_type to 'flex'.
+    # Must set moe_shared_expert_overlap to False.
+    moe_enable_deepep: bool
+    # The type of token dispatcher to use. The default is 'allgather'.
+    # Options are 'allgather', 'alltoall', and 'flex'.
+    # Use 'flex' when using DeepEP.
+    moe_token_dispatcher_type: str
+    # Can only be used with the 'alltoall' token dispatcher.
+    moe_shared_expert_overlap: bool
     optimizer: MegatronOptimizerConfig
     scheduler: MegatronSchedulerConfig
     distributed_data_parallel_config: MegatronDDPConfig
diff --git a/nemo_rl/models/policy/workers/megatron_policy_worker.py b/nemo_rl/models/policy/workers/megatron_policy_worker.py
index 1d175f35b2..63b6691f13 100644
--- a/nemo_rl/models/policy/workers/megatron_policy_worker.py
+++ b/nemo_rl/models/policy/workers/megatron_policy_worker.py
@@ -658,6 +658,13 @@ def __init__(
         model_cfg.moe_router_bias_update_rate = self.cfg["megatron_cfg"][
             "moe_router_bias_update_rate"
         ]
+        model_cfg.moe_enable_deepep = self.cfg["megatron_cfg"]["moe_enable_deepep"]
+        model_cfg.moe_token_dispatcher_type = self.cfg["megatron_cfg"][
+            "moe_token_dispatcher_type"
+        ]
+        model_cfg.moe_shared_expert_overlap = self.cfg["megatron_cfg"][
+            "moe_shared_expert_overlap"
+        ]
         model_cfg.moe_permute_fusion = self.cfg["megatron_cfg"]["moe_permute_fusion"]

         if "layernorm_epsilon" in self.cfg["megatron_cfg"]:
diff --git a/pyproject.toml b/pyproject.toml
index 29e683fdbe..e336a3c4f4 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -100,6 +100,7 @@ mcore = [
     # https://github.com/NVIDIA/TransformerEngine/blob/v2.3/transformer_engine/pytorch/attention/dot_product_attention/utils.py#L108
     # https://github.com/facebookresearch/xformers/blob/8354497deb2c04c67fbb2e2ad911e86530da0e90/xformers/ops/fmha/flash.py#L76
     "flash-attn==2.8.1",
+    "deep_ep @ git+https://github.com/deepseek-ai/DeepEP.git@bfded34800dfec415b71503f8205181de90b2480",
     # Remove this once https://github.com/NVIDIA-NeMo/RL/issues/501 resolved
     "vllm==0.11.2",
 ]
diff --git a/tests/unit/models/generation/test_vllm_generation.py b/tests/unit/models/generation/test_vllm_generation.py
index 1599b7e703..87730c8908 100644
--- a/tests/unit/models/generation/test_vllm_generation.py
+++ b/tests/unit/models/generation/test_vllm_generation.py
@@ -178,6 +178,9 @@ def get_basic_megatron_test_config(
             "moe_router_load_balancing_type": "none",
             "moe_router_bias_update_rate": 0.0,
             "moe_permute_fusion": False,
+            "moe_enable_deepep": False,
+            "moe_token_dispatcher_type": "allgather",
+            "moe_shared_expert_overlap": False,
             "apply_rope_fusion": True,
             "bias_activation_fusion": True,
             "moe_per_layer_logging": False,
diff --git a/tests/unit/models/policy/test_megatron_worker.py b/tests/unit/models/policy/test_megatron_worker.py
index 5f09460cfb..426c64a0d1 100644
--- a/tests/unit/models/policy/test_megatron_worker.py
+++ b/tests/unit/models/policy/test_megatron_worker.py
@@ -135,6 +135,9 @@ def create_megatron_test_config(
             "apply_rope_fusion": True,
             "bias_activation_fusion": True,
             "moe_per_layer_logging": False,
+            "moe_enable_deepep": False,
+            "moe_token_dispatcher_type": "allgather",
+            "moe_shared_expert_overlap": False,
             "defer_fp32_logits": defer_fp32_logits,
             "train_iters": 100,  # Required for Megatron training
             "optimizer": {
diff --git a/uv.lock b/uv.lock
index e6b1c3fe30..19c1375141 100644
--- a/uv.lock
+++ b/uv.lock
@@ -3950,6 +3950,7 @@ fsdp = [
     { name = "vllm" },
 ]
 mcore = [
+    { name = "deep-ep" },
     { name = "flash-attn" },
     { name = "megatron-bridge" },
     { name = "megatron-core" },
@@ -4017,6 +4018,7 @@ requires-dist = [
     { name = "datasets", specifier = ">=4.0.0" },
     { name = "debugpy" },
     { name = "deep-ep", marker = "extra == 'automodel'", git = "https://github.com/deepseek-ai/DeepEP.git?rev=bfded34800dfec415b71503f8205181de90b2480" },
+    { name = "deep-ep", marker = "extra == 'mcore'", git = "https://github.com/deepseek-ai/DeepEP.git?rev=bfded34800dfec415b71503f8205181de90b2480" },
    { name = "deep-ep", marker = "extra == 'vllm'", git = "https://github.com/deepseek-ai/DeepEP.git?rev=bfded34800dfec415b71503f8205181de90b2480" },
     { name = "deep-gemm", marker = "extra == 'vllm'", git = "https://github.com/deepseek-ai/DeepGEMM.git?rev=7b6b5563b9d4c1ae07ffbce7f78ad3ac9204827c" },
     { name = "flash-attn", marker = "extra == 'automodel'", specifier = "==2.8.1" },
@@ -4873,20 +4875,21 @@ wheels = [

 [[package]]
 name = "perceptron"
-version = "0.1.4"
+version = "0.2.0"
 source = { registry = "https://pypi.org/simple" }
 dependencies = [
     { name = "colorama" },
     { name = "httpx", extra = ["http2"] },
     { name = "numpy" },
     { name = "pillow" },
+    { name = "pydantic" },
     { name = "rich" },
     { name = "shellingham" },
     { name = "typer" },
 ]
-sdist = { url = "https://files.pythonhosted.org/packages/30/60/85db2243d8b550823603d8f9c5845b0dd0f01074e9aabf0b2af0c4f52565/perceptron-0.1.4.tar.gz", hash = "sha256:62fd190efb74925e2cc33c0cd38761e19959be3bdb7b24fbf9e3386d6961f690", size = 78116, upload-time = "2025-11-12T20:00:28.024Z" }
+sdist = { url = "https://files.pythonhosted.org/packages/c6/ff/87efbc3988094e09eb29261d545c84cd0a21376daa997435f5566281e2d2/perceptron-0.2.0.tar.gz", hash = "sha256:369ff3078ba7ac9e3b5f30d9f75ff44d72991b64c94f93c5267e751552cab3f6", size = 87447, upload-time = "2026-01-14T23:42:37.713Z" }
 wheels = [
-    { url = "https://files.pythonhosted.org/packages/ef/17/b7cb1a10ebb0a9a4c9fbcd96a28b43d44e08a90f620bab07e644a658d2f1/perceptron-0.1.4-py3-none-any.whl", hash = "sha256:f490a6df6c15167e91e1a528601cae98ce99a30991cf792f9ef83ebc15d335c4", size = 57421, upload-time = "2025-11-12T20:00:26.395Z" },
+    { url = "https://files.pythonhosted.org/packages/8b/83/983a6663a7814c0772eabdf3f2e616758abd50a244dfbd770785c9c2ab95/perceptron-0.2.0-py3-none-any.whl", hash = "sha256:7dc7713778b797f3cb013406eb507ae729ca360347dba8196e82361134a436e8", size = 61076, upload-time = "2026-01-14T23:42:36.525Z" },
 ]

 [[package]]