Skip to content

feat: add EP#3632

Merged
winglian merged 15 commits into
mainfrom
feat/ep
May 21, 2026
Merged

feat: add EP#3632
winglian merged 15 commits into
mainfrom
feat/ep

Conversation

@NanoCode012
Copy link
Copy Markdown
Collaborator

@NanoCode012 NanoCode012 commented Apr 29, 2026

Description

Adds EP via DeepEP
Tested on 2xA100 EP, 4xA100 EP & EP/FSDP and individual bench.
To test:

  • EP (N=4)
  • EP+FSDP (N=4)
  • Hopper specific changes Lowlatency etc [not yet planned]
  • EP N-D compose (TP/CP) [not planned]
    Very experimental.

Motivation and Context

How has this been tested?

Sweep and E2E tests with loss match with atol and grad norm + correctness check single layer.

AI Usage Disclaimer

Claude heavily

Screenshots (if appropriate)

Types of changes

Social Handles (Optional)

Summary by CodeRabbit

  • New Features

    • Expert Parallelism support enabling efficient distributed training of mixture-of-experts models across multiple ranks and hardware configurations
  • Documentation

    • Added Expert Parallelism training guide including configuration examples, usage patterns, compatibility information, and troubleshooting instructions
  • Bug Fixes

    • Fixed processor argument handling in multimodal chat batch collator for improved compatibility

@coderabbitai
Copy link
Copy Markdown
Contributor

coderabbitai Bot commented Apr 29, 2026

Important

Review skipped

Auto incremental reviews are disabled on this repository.

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 9f9f6858-e736-4d9f-b906-c6cc3d756418

You can disable this status message by setting the reviews.review_status to false in the CodeRabbit configuration file.

Use the checkbox below for a quick retry:

  • 🔍 Trigger review
📝 Walkthrough

Walkthrough

This pull request introduces Expert Parallelism (EP) support via DeepEP integration, enabling efficient distributed training for Mixture-of-Experts models. The implementation adds a complete integration package with plugin architecture, expert sharding, buffer management, kernel registration, and mesh composition with FSDP. Changes span documentation, configuration examples, core integration modules, parallelism infrastructure modifications, and comprehensive tests.

Changes

Cohort / File(s) Summary
Documentation & Examples
docs/nd_parallelism.qmd, examples/expert_parallel/ep_30b_fullft_4gpu.yaml, examples/expert_parallel/ep_fsdp_30b_fullft_4gpu.yaml, src/axolotl/integrations/expert_parallel/README.md
Adds EP section to parallelism documentation with schema updates, support matrix, and mesh composition constraints. Provides two example configurations: pure EP (4 expert ranks) and EP+FSDP (2D composition). README details hardware requirements, installation, and DeepEP usage patterns.
Core EP Integration Package
src/axolotl/integrations/expert_parallel/__init__.py, src/axolotl/integrations/expert_parallel/args.py, src/axolotl/integrations/expert_parallel/buffer.py
Initializes EP integration with ExpertParallelArgs configuration model (validates expert_parallel_size, RDMA warnings), and lazy-loaded Buffer singleton for DeepEP runtime state.
Expert Kernel Registration & Dispatch
src/axolotl/integrations/expert_parallel/experts_fn.py
Implements four DeepEP forward entrypoints (deep_ep_experts_forward, deep_ep_grouped_mm_experts_forward, etc.) with autograd dispatch/combine via DeepEP buffers, kernel selection logic, and transformers integration registry. High-density autograd and kernel orchestration logic.
Expert Parallelism Plugin
src/axolotl/integrations/expert_parallel/plugin.py
Core plugin lifecycle: registers kernels, infers local kernel variants, validates mesh axes, shards expert weights, configures buffers, manages EP process groups (pure EP and EP×FSDP), and registers gradient scaling hooks. Handles FSDP pre-sharding and DDP parameter exclusion propagation.
Expert Weight Sharding
src/axolotl/integrations/expert_parallel/shard.py
Slices expert module weights (gate_up_proj, down_proj) along experts dimension per EP rank, updates module metadata, and updates DDP ignore lists. Includes cleanup and logging.
Parallelism Infrastructure
src/axolotl/monkeypatch/accelerate/fsdp2.py, src/axolotl/monkeypatch/accelerate/parallelism_config.py, src/axolotl/utils/distributed.py, src/axolotl/utils/trainer.py
Extends Accelerate parallelism config to include EP as a mesh dimension, reorders device-mesh construction (ep, dp_replicate, dp_shard, cp, sp, tp), adds EP-aware gradient norm clipping and data loader preparation, integrates EP sizing into distributed config builders.
Comprehensive Integration Tests
tests/integrations/test_expert_parallel.py
Tests ExpertParallelArgs validation, kernel inference and registration, expert module detection via Qwen3-MoE, weight sharding idempotency, plugin lifecycle, and distributed EP+FSDP mesh topology with gloo spawning.
Collator Fix
src/axolotl/utils/collators/mm_chat.py
Updates process_rows to pass padding via processor_kwargs instead of direct argument.

Sequence Diagram

sequenceDiagram
    participant User
    participant Plugin as ExpertParallelPlugin
    participant DeepEP as DeepEP<br/>(Buffer)
    participant Model as MoE Model<br/>(Experts)
    participant Output

    User->>Plugin: configure expert_parallel_size > 1
    Plugin->>Plugin: pre_model_load: register kernels
    Plugin->>Model: post_model_build: shard experts
    Plugin->>DeepEP: configure_buffer(ep_group, nvl_bytes)
    
    activate Model
    Model->>DeepEP: forward: dispatch(hidden_states, routing)
    Note over DeepEP: Send tokens to<br/>correct expert ranks
    
    DeepEP->>Model: receive dispatched tokens
    Model->>Model: run local experts<br/>(grouped_mm, scattermoe, etc.)
    Model->>DeepEP: combine(expert_outputs)
    Note over DeepEP: Reconstruct per-token<br/>outputs across ranks
    DeepEP->>Output: per-token expert output
    deactivate Model
Loading

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

This PR introduces substantial logic across multiple domains: distributed expert sharding (166 lines), high-density kernel registration and dispatch (277 lines), complex plugin lifecycle management (361 lines), and deep Accelerate monkeypatching (279 lines). The heterogeneous nature—spanning buffer management, autograd integration, mesh composition, and DDP/FSDP interop—requires careful reasoning for each subsystem and understanding of distributed training semantics.

Possibly related PRs

  • #3410: Implements ScatterMoE and SonicMoE kernel modules that are integrated and registered as DeepEP expert implementations in this PR.
  • #3019: Modifies Accelerate parallelism config and device-mesh construction similarly; both PRs extend mesh axis handling and coordinate multiple parallelism dimensions.
  • #2977: ND-Parallel and context-parallel refactor that similarly extends parallelism infrastructure, FSDP integration, and accelerate monkeypatching for new axes.

Suggested reviewers

  • SalmanMohammadi
  • winglian
  • djsaunde

Poem

🐰 A rabbit hops through expert ranks so fine,
Each specialist now trains in parallel line,
DeepEP kernels dispatch and combine with grace,
No single expert bears the burden's race!
Hopping faster, sharding wide—the MoE takes flight!

🚥 Pre-merge checks | ✅ 3 | ❌ 2

❌ Failed checks (1 warning, 1 inconclusive)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 35.35% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
Title check ❓ Inconclusive The title 'feat: add EP' is extremely vague and does not clearly convey the main change, using a non-descriptive abbreviation 'EP' without context that would be meaningful to someone scanning commit history. Use a more descriptive title like 'feat: add expert parallel (EP) integration with DeepEP' or 'feat: add expert-parallel (EP) distributed training support' to clearly communicate the feature being added.
✅ Passed checks (3 passed)
Check name Status Explanation
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Commit unit tests in branch feat/ep

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

@NanoCode012 NanoCode012 marked this pull request as ready for review April 30, 2026 13:26
@github-actions
Copy link
Copy Markdown
Contributor

github-actions Bot commented Apr 30, 2026

📖 Documentation Preview: https://6a0eb2f0a659e418c102248d--resonant-treacle-0fd729.netlify.app

Deployed on Netlify from commit 7e2fbde

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 3

🧹 Nitpick comments (7)
DEEPEP_SETUP.md (1)

87-90: 💤 Low value

Add language specifier to fenced code block.

The error message code block is missing a language identifier. While this is a minor linting issue, adding a specifier improves rendering consistency.

📝 Suggested fix
-```
+```text
 ImportError: deep_ep_cpp.cpython-...-x86_64-linux-gnu.so: undefined symbol:
 _ZN7deep_ep12internode_ll17query_mask_bufferEPiiS1_P11CUstream_st
</details>

<details>
<summary>🤖 Prompt for AI Agents</summary>

Verify each finding against the current code and only fix it if needed.

In @DEEPEP_SETUP.md around lines 87 - 90, The fenced code block in
DEEPEP_SETUP.md containing the ImportError message lacks a language specifier;
update that block to include a language identifier (e.g., "text") so the block
reads text ... to improve linting and rendering consistency—locate the
ImportError snippet in DEEPEP_SETUP.md and add the "text" specifier to the
opening backticks around the block containing the
_ZN7deep_ep12internode_ll17query_mask_buffer... symbol.


</details>

</blockquote></details>
<details>
<summary>src/axolotl/integrations/expert_parallel/buffer.py (1)</summary><blockquote>

`59-62`: _💤 Low value_

**`reset_buffer()` doesn't reset configuration parameters.**

When `reset_buffer()` is called (e.g., between tests), only `_BUFFER` is cleared but `_EP_GROUP`, `_NUM_NVL_BYTES`, and `_NUM_RDMA_BYTES` retain their previous values. If a subsequent test calls `get_buffer()` without first calling `configure_buffer()`, it will use stale configuration from the previous test.



<details>
<summary>💡 Consider resetting all state for better test isolation</summary>

```diff
 def reset_buffer() -> None:
     """Drop the cached Buffer. Used in tests."""
-    global _BUFFER
+    global _BUFFER, _EP_GROUP, _NUM_NVL_BYTES, _NUM_RDMA_BYTES
     _BUFFER = None
+    _EP_GROUP = None
+    _NUM_NVL_BYTES = 256 << 20
+    _NUM_RDMA_BYTES = 0
```
</details>

<details>
<summary>🤖 Prompt for AI Agents</summary>

```
Verify each finding against the current code and only fix it if needed.

In `@src/axolotl/integrations/expert_parallel/buffer.py` around lines 59 - 62,
reset_buffer() currently only clears the cached _BUFFER leaving _EP_GROUP,
_NUM_NVL_BYTES, and _NUM_RDMA_BYTES with stale values; update reset_buffer to
also reset those globals to their initial/default state (e.g., None or zero as
used by configure_buffer/get_buffer) so tests calling get_buffer() without
reconfiguring won't see previous configuration; reference the symbols _EP_GROUP,
_NUM_NVL_BYTES, _NUM_RDMA_BYTES, reset_buffer(), configure_buffer(), and
get_buffer() when making the change.
```

</details>

</blockquote></details>
<details>
<summary>src/axolotl/utils/distributed.py (1)</summary><blockquote>

`332-340`: _💤 Low value_

**EP world size handling is correct but relies on separate env var/monkeypatch.**

The EP size is correctly validated and divided from `remaining_world_size`, but unlike TP/CP/DP, it's not added to `pc_kwargs`. This is intentional since EP mesh handling is done via `PARALLELISM_CONFIG_EP_SIZE` env var and `patch_parallelism_config()` in `trainer.py`.

This two-path approach (standard config for TP/CP/DP, monkeypatch for EP) works but may be worth documenting in a code comment for future maintainers.



<details>
<summary>💡 Consider adding a clarifying comment</summary>

```diff
     # EP consumes part of world_size; subtract it up front so the auto-fill
     # below doesn't put EP ranks into `dp_replicate_size`.
+    # Note: ep_size is NOT added to pc_kwargs; it's handled separately via
+    # PARALLELISM_CONFIG_EP_SIZE env var and patch_parallelism_config().
     if expert_parallel_size and expert_parallel_size > 1:
```
</details>

<details>
<summary>🤖 Prompt for AI Agents</summary>

```
Verify each finding against the current code and only fix it if needed.

In `@src/axolotl/utils/distributed.py` around lines 332 - 340, The EP handling
block validates and divides remaining_world_size using expert_parallel_size but
does not add EP size to pc_kwargs because EP is configured via the
PARALLELISM_CONFIG_EP_SIZE env var and patch_parallelism_config() in trainer.py;
add a brief clarifying comment next to the
expert_parallel_size/remaining_world_size block that states that EP is
intentionally excluded from pc_kwargs, references PARALLELISM_CONFIG_EP_SIZE and
patch_parallelism_config(), and explains that EP mesh is applied via
environment/monkeypatch elsewhere so future maintainers won't try to add it to
pc_kwargs.
```

</details>

</blockquote></details>
<details>
<summary>docs/nd_parallelism.qmd (1)</summary><blockquote>

`104-106`: _💤 Low value_

**Example 4 is helpful but could clarify plugin configuration.**

The example mentions adding `ExpertParallelPlugin` to `plugins:` but doesn't show the exact YAML syntax. Consider adding the full plugin path for clarity.



<details>
<summary>📝 Consider expanding the example</summary>

```diff
 4.  FSDP + EP on a 4-GPU MoE training run:
     - You want EP to shard the experts and FSDP to shard non-expert params on orthogonal mesh axes.
-    - Set `expert_parallel_size: 2` and `dp_shard_size: 2` (`ep × dp_shard == world_size`). Add the `ExpertParallelPlugin` to `plugins:`.
+    - Set `expert_parallel_size: 2` and `dp_shard_size: 2` (`ep × dp_shard == world_size`). Add the plugin:
+      ```yaml
+      plugins:
+        - axolotl.integrations.expert_parallel.ExpertParallelPlugin
+      ```
```
</details>

<details>
<summary>🤖 Prompt for AI Agents</summary>

```
Verify each finding against the current code and only fix it if needed.

In `@docs/nd_parallelism.qmd` around lines 104 - 106, Clarify the plugin
configuration by showing the fully-qualified plugin path in the example: when
using expert_parallel_size and dp_shard_size with ExpertParallelPlugin, update
the plugins block to include the full class path (e.g.,
axolotl.integrations.expert_parallel.ExpertParallelPlugin) so readers can
copy-paste the exact YAML; ensure the example references ExpertParallelPlugin,
expert_parallel_size, dp_shard_size, and plugins together for clarity.
```

</details>

</blockquote></details>
<details>
<summary>src/axolotl/integrations/expert_parallel/README.md (1)</summary><blockquote>

`113-113`: _💤 Low value_

**Fix heading level increment.**

The heading jumps from `##` (h2) to `####` (h4), skipping h3. This breaks accessibility and consistent document structure per markdown best practices.


<details>
<summary>Proposed fix</summary>

```diff
-#### Implementation notes
+### Implementation notes
```
</details>

<details>
<summary>🤖 Prompt for AI Agents</summary>

```
Verify each finding against the current code and only fix it if needed.

In `@src/axolotl/integrations/expert_parallel/README.md` at line 113, The
"Implementation notes" heading currently uses h4 (#### Implementation notes)
which skips h3 and breaks document structure; change that heading to h3 (###
Implementation notes) so the sequence follows the preceding h2 and maintains
proper accessibility and semantic hierarchy in the README.
```

</details>

</blockquote></details>
<details>
<summary>tests/integrations/test_expert_parallel.py (1)</summary><blockquote>

`148-164`: _💤 Low value_

**Good test coverage for expert module detection.**

Tests correctly verify that 3D expert tensors are detected while non-3D modules are skipped. Minor: the unpacked `name` variable on line 153 is unused.


<details>
<summary>Silence unused variable warning</summary>

```diff
-        name, module = found[0]
+        _name, module = found[0]
```
</details>

<details>
<summary>🤖 Prompt for AI Agents</summary>

```
Verify each finding against the current code and only fix it if needed.

In `@tests/integrations/test_expert_parallel.py` around lines 148 - 164, The test
unpacks a (name, module) tuple but never uses name in
TestExpertModuleDetection.test_detects_qwen3moe_experts; update the unpacking to
avoid an unused-variable warning by replacing "name, module = found[0]" with a
discard (e.g., "_, module = found[0]") or by directly accessing the module
(e.g., "module = found[0][1]") where the tuple comes from
_detect_experts_modules; keep the rest of the assertions unchanged.
```

</details>

</blockquote></details>
<details>
<summary>src/axolotl/integrations/expert_parallel/plugin.py (1)</summary><blockquote>

`120-128`: _💤 Low value_

**Suffix matching may over-match parameter names.**

The condition `full.endswith(short_name)` without a dot prefix could match unrelated parameters. For example, if `short_name="proj"`, it would match both `gate_up_proj` and `down_proj`, which may be intended but could also match `some_other_proj` unexpectedly.

However, since `ignore_list` comes from sharded expert modules and typically contains full submodule paths like `experts.gate_up_proj`, this is likely safe in practice.

<details>
<summary>🤖 Prompt for AI Agents</summary>

```
Verify each finding against the current code and only fix it if needed.

In `@src/axolotl/integrations/expert_parallel/plugin.py` around lines 120 - 128,
Suffix-only matching can over-match unrelated parameter names; in the loop over
ignore_list/all_names in plugin.py (variables: ignore_list, all_names, resolved,
short_name, full) replace the broad full.endswith(short_name) check with a safer
check that compares the final path segment (e.g., full.rsplit(".", 1)[-1] ==
short_name) or requires a dot prefix (full.endswith("." + short_name)); update
the conditional so matches are only when full equals short_name, ends with "." +
short_name, or the final segment equals short_name, then append to resolved as
before.
```

</details>

</blockquote></details>

</blockquote></details>

<details>
<summary>🤖 Prompt for all review comments with AI agents</summary>

Verify each finding against the current code and only fix it if needed.

Inline comments:
In @examples/expert_parallel/ep_30b_fullft_4gpu.yaml:

  • Line 26: The config uses warmup_steps: 0.1 which is invalid because
    warmup_steps expects an integer; replace the key warmup_steps with warmup_ratio
    (i.e., change warmup_steps -> warmup_ratio) so the float 0.1 is treated as a
    ratio, or alternatively set warmup_steps to an integer if you really want a step
    count; update any references expecting warmup_steps accordingly (look for
    warmup_steps and warmup_ratio in the config or loader).

In @examples/expert_parallel/ep_fsdp_30b_fullft_4gpu.yaml:

  • Line 28: The YAML sets warmup_steps: 0.1 which is a misconfiguration because
    warmup_steps expects an integer; change the key to warmup_ratio and keep the 0.1
    value (i.e., replace warmup_steps with warmup_ratio) so the scheduler reads a
    ratio instead of an integer—update the entry for warmup_steps to warmup_ratio in
    the config (referencing the warmup_steps/warmup_ratio keys).

In @src/axolotl/monkeypatch/accelerate/parallelism_config.py:

  • Around line 85-92: The _patched_get_mesh function can raise ValueError when
    mesh_order.index(x[0]) is called for an unknown dimension name; update
    _patched_get_mesh to validate or handle unexpected names by filtering or mapping
    active_mesh_dims against the allowed mesh_order before sorting (e.g., build
    mesh_dims from self.active_mesh_dims but only include keys present in
    mesh_order, or assign a fallback index for unknown keys), and ensure the
    returned sorted_items (used to produce the tuple) only contains known dimensions
    so zip(*sorted_items, strict=True) stays safe; reference the symbols
    _patched_get_mesh, mesh_order, mesh_dims, active_mesh_dims, and sorted_items
    when making the change.

Nitpick comments:
In @DEEPEP_SETUP.md:

  • Around line 87-90: The fenced code block in DEEPEP_SETUP.md containing the
    ImportError message lacks a language specifier; update that block to include a
    language identifier (e.g., "text") so the block reads text ... to improve
    linting and rendering consistency—locate the ImportError snippet in
    DEEPEP_SETUP.md and add the "text" specifier to the opening backticks around the
    block containing the _ZN7deep_ep12internode_ll17query_mask_buffer... symbol.

In @docs/nd_parallelism.qmd:

  • Around line 104-106: Clarify the plugin configuration by showing the
    fully-qualified plugin path in the example: when using expert_parallel_size and
    dp_shard_size with ExpertParallelPlugin, update the plugins block to include the
    full class path (e.g.,
    axolotl.integrations.expert_parallel.ExpertParallelPlugin) so readers can
    copy-paste the exact YAML; ensure the example references ExpertParallelPlugin,
    expert_parallel_size, dp_shard_size, and plugins together for clarity.

In @src/axolotl/integrations/expert_parallel/buffer.py:

  • Around line 59-62: reset_buffer() currently only clears the cached _BUFFER
    leaving _EP_GROUP, _NUM_NVL_BYTES, and _NUM_RDMA_BYTES with stale values; update
    reset_buffer to also reset those globals to their initial/default state (e.g.,
    None or zero as used by configure_buffer/get_buffer) so tests calling
    get_buffer() without reconfiguring won't see previous configuration; reference
    the symbols _EP_GROUP, _NUM_NVL_BYTES, _NUM_RDMA_BYTES, reset_buffer(),
    configure_buffer(), and get_buffer() when making the change.

In @src/axolotl/integrations/expert_parallel/plugin.py:

  • Around line 120-128: Suffix-only matching can over-match unrelated parameter
    names; in the loop over ignore_list/all_names in plugin.py (variables:
    ignore_list, all_names, resolved, short_name, full) replace the broad
    full.endswith(short_name) check with a safer check that compares the final path
    segment (e.g., full.rsplit(".", 1)[-1] == short_name) or requires a dot prefix
    (full.endswith("." + short_name)); update the conditional so matches are only
    when full equals short_name, ends with "." + short_name, or the final segment
    equals short_name, then append to resolved as before.

In @src/axolotl/integrations/expert_parallel/README.md:

  • Line 113: The "Implementation notes" heading currently uses h4 (####
    Implementation notes) which skips h3 and breaks document structure; change that
    heading to h3 (### Implementation notes) so the sequence follows the preceding
    h2 and maintains proper accessibility and semantic hierarchy in the README.

In @src/axolotl/utils/distributed.py:

  • Around line 332-340: The EP handling block validates and divides
    remaining_world_size using expert_parallel_size but does not add EP size to
    pc_kwargs because EP is configured via the PARALLELISM_CONFIG_EP_SIZE env var
    and patch_parallelism_config() in trainer.py; add a brief clarifying comment
    next to the expert_parallel_size/remaining_world_size block that states that EP
    is intentionally excluded from pc_kwargs, references PARALLELISM_CONFIG_EP_SIZE
    and patch_parallelism_config(), and explains that EP mesh is applied via
    environment/monkeypatch elsewhere so future maintainers won't try to add it to
    pc_kwargs.

In @tests/integrations/test_expert_parallel.py:

  • Around line 148-164: The test unpacks a (name, module) tuple but never uses
    name in TestExpertModuleDetection.test_detects_qwen3moe_experts; update the
    unpacking to avoid an unused-variable warning by replacing "name, module =
    found[0]" with a discard (e.g., "_, module = found[0]") or by directly accessing
    the module (e.g., "module = found[0][1]") where the tuple comes from
    _detect_experts_modules; keep the rest of the assertions unchanged.

</details>

<details>
<summary>🪄 Autofix (Beta)</summary>

Fix all unresolved CodeRabbit comments on this PR:

- [ ] <!-- {"checkboxId": "4b0d0e0a-96d7-4f10-b296-3a18ea78f0b9"} --> Push a commit to this branch (recommended)
- [ ] <!-- {"checkboxId": "ff5b1114-7d8c-49e6-8ac1-43f82af23a33"} --> Create a new PR with the fixes

</details>

---

<details>
<summary>ℹ️ Review info</summary>

<details>
<summary>⚙️ Run configuration</summary>

**Configuration used**: Path: .coderabbit.yaml

**Review profile**: CHILL

**Plan**: Pro

**Run ID**: `4d41e62d-ddb9-4683-8e93-7d88aae13904`

</details>

<details>
<summary>📥 Commits</summary>

Reviewing files that changed from the base of the PR and between ac77da96daf074475fc8f207b179510ed7ba3f17 and ca82a1a028e607ce4f1c1adecd67b28b93da9688.

</details>

<details>
<summary>📒 Files selected for processing (16)</summary>

* `DEEPEP_SETUP.md`
* `docs/nd_parallelism.qmd`
* `examples/expert_parallel/ep_30b_fullft_4gpu.yaml`
* `examples/expert_parallel/ep_fsdp_30b_fullft_4gpu.yaml`
* `src/axolotl/integrations/expert_parallel/README.md`
* `src/axolotl/integrations/expert_parallel/__init__.py`
* `src/axolotl/integrations/expert_parallel/args.py`
* `src/axolotl/integrations/expert_parallel/buffer.py`
* `src/axolotl/integrations/expert_parallel/experts_fn.py`
* `src/axolotl/integrations/expert_parallel/plugin.py`
* `src/axolotl/integrations/expert_parallel/shard.py`
* `src/axolotl/monkeypatch/accelerate/fsdp2.py`
* `src/axolotl/monkeypatch/accelerate/parallelism_config.py`
* `src/axolotl/utils/distributed.py`
* `src/axolotl/utils/trainer.py`
* `tests/integrations/test_expert_parallel.py`

</details>

</details>

<!-- This is an auto-generated comment by CodeRabbit for review status -->

optimizer: adamw_torch_8bit
lr_scheduler: cosine
learning_rate: 0.003
warmup_steps: 0.1
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Likely misconfiguration: warmup_steps should be warmup_ratio.

warmup_steps expects an integer (number of steps), but the value 0.1 suggests a ratio was intended. This will likely cause unexpected behavior or an error.

Proposed fix
-warmup_steps: 0.1
+warmup_ratio: 0.1
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
warmup_steps: 0.1
warmup_ratio: 0.1
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/expert_parallel/ep_30b_fullft_4gpu.yaml` at line 26, The config uses
warmup_steps: 0.1 which is invalid because warmup_steps expects an integer;
replace the key warmup_steps with warmup_ratio (i.e., change warmup_steps ->
warmup_ratio) so the float 0.1 is treated as a ratio, or alternatively set
warmup_steps to an integer if you really want a step count; update any
references expecting warmup_steps accordingly (look for warmup_steps and
warmup_ratio in the config or loader).

optimizer: adamw_torch_8bit
lr_scheduler: cosine
learning_rate: 0.003
warmup_steps: 0.1
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Same misconfiguration: warmup_steps should be warmup_ratio.

Same issue as in ep_30b_fullft_4gpu.yamlwarmup_steps expects an integer, not a float ratio.

Proposed fix
-warmup_steps: 0.1
+warmup_ratio: 0.1
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
warmup_steps: 0.1
warmup_ratio: 0.1
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/expert_parallel/ep_fsdp_30b_fullft_4gpu.yaml` at line 28, The YAML
sets warmup_steps: 0.1 which is a misconfiguration because warmup_steps expects
an integer; change the key to warmup_ratio and keep the 0.1 value (i.e., replace
warmup_steps with warmup_ratio) so the scheduler reads a ratio instead of an
integer—update the entry for warmup_steps to warmup_ratio in the config
(referencing the warmup_steps/warmup_ratio keys).

Comment on lines +85 to +92
def _patched_get_mesh(self):
"""Build (dim_names, shape) for `init_device_mesh`. Order keeps the dp
block (ep, dp_replicate, dp_shard) contiguous so `_flatten("dp")` works.
"""
mesh_dims = {p: self._sizes[p] for p in self.active_mesh_dims}
mesh_order = ["ep", "dp_replicate", "dp_shard", "cp", "sp", "tp"]
sorted_items = sorted(mesh_dims.items(), key=lambda x: mesh_order.index(x[0]))
return tuple(zip(*sorted_items, strict=True))
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Potential ValueError if mesh contains unexpected dimension names.

mesh_order.index(x[0]) will raise ValueError if active_mesh_dims contains a dimension not in mesh_order. Consider adding a fallback or validation.

Proposed defensive fix
 def _patched_get_mesh(self):
     """Build (dim_names, shape) for `init_device_mesh`. Order keeps the dp
     block (ep, dp_replicate, dp_shard) contiguous so `_flatten("dp")` works.
     """
     mesh_dims = {p: self._sizes[p] for p in self.active_mesh_dims}
     mesh_order = ["ep", "dp_replicate", "dp_shard", "cp", "sp", "tp"]
-    sorted_items = sorted(mesh_dims.items(), key=lambda x: mesh_order.index(x[0]))
+    def order_key(item):
+        try:
+            return mesh_order.index(item[0])
+        except ValueError:
+            return len(mesh_order)  # unknown dims go last
+    sorted_items = sorted(mesh_dims.items(), key=order_key)
     return tuple(zip(*sorted_items, strict=True))
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/axolotl/monkeypatch/accelerate/parallelism_config.py` around lines 85 -
92, The _patched_get_mesh function can raise ValueError when
mesh_order.index(x[0]) is called for an unknown dimension name; update
_patched_get_mesh to validate or handle unexpected names by filtering or mapping
active_mesh_dims against the allowed mesh_order before sorting (e.g., build
mesh_dims from self.active_mesh_dims but only include keys present in
mesh_order, or assign a fallback index for unknown keys), and ensure the
returned sorted_items (used to produce the tuple) only contains known dimensions
so zip(*sorted_items, strict=True) stays safe; reference the symbols
_patched_get_mesh, mesh_order, mesh_dims, active_mesh_dims, and sorted_items
when making the change.

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🧹 Nitpick comments (2)
src/axolotl/integrations/expert_parallel/plugin.py (2)

229-229: 💤 Low value

Use ASCII x instead of Unicode × in error messages.

The multiplication sign × (Unicode) can cause display issues in some terminals and logging systems.

♻️ Suggested fix
-                    f"ep={ep_size}, dp_shard={dp_shard_size}, tp={tp_size}, cp={cp_size}. "
-                    "v1 supports only EP-only or EP × dp_shard."
+                    f"ep={ep_size}, dp_shard={dp_shard_size}, tp={tp_size}, cp={cp_size}. "
+                    "v1 supports only EP-only or EP x dp_shard."

Also on line 252:

-            "Set dp_shard_size such that ep × dp_shard == world_size, or set "
+            "Set dp_shard_size such that ep x dp_shard == world_size, or set "
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/axolotl/integrations/expert_parallel/plugin.py` at line 229, Replace the
Unicode multiplication sign '×' with the ASCII letter 'x' in the error message
string "v1 supports only EP-only or EP × dp_shard." (and the other similar
message referenced at the other location) so it reads "v1 supports only EP-only
or EP x dp_shard."; update both occurrences to avoid terminals/logging display
issues.

118-126: 💤 Low value

Suffix matching logic could match unintended parameters.

The condition on line 124 (full.endswith(short_name)) matches without requiring a dot separator, which could incorrectly match parameters that share a suffix. For example, if short_name = "proj", it would match "gate_up_proj".

In practice, short_name values from shard.py are fully qualified paths (e.g., "layers.0.experts.gate_up_proj"), so false positives are unlikely, but the matching could be tightened:

♻️ Suggested fix to tighten matching
         for short_name in ignore_list:
             # Match either an exact suffix or with PEFT's `base_model.model.` prefix.
             for full in all_names:
-                if (
-                    full == short_name
-                    or full.endswith("." + short_name)
-                    or full.endswith(short_name)
-                ):
+                if full == short_name or full.endswith("." + short_name):
                     resolved.append(full)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/axolotl/integrations/expert_parallel/plugin.py` around lines 118 - 126,
The matching in the loop using full, short_name, ignore_list, all_names, and
resolved is too permissive because full.endswith(short_name) can match
unintended suffixes; tighten it so you only accept exact matches or suffixes
preceded by a dot (or the PEFT prefix case already handled), e.g., replace the
loose full.endswith(short_name) check with a check that requires a dot separator
before short_name (or use a regex like (^|\\.)short_name$) so only true
path-segment suffixes are resolved into resolved.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Nitpick comments:
In `@src/axolotl/integrations/expert_parallel/plugin.py`:
- Line 229: Replace the Unicode multiplication sign '×' with the ASCII letter
'x' in the error message string "v1 supports only EP-only or EP × dp_shard."
(and the other similar message referenced at the other location) so it reads "v1
supports only EP-only or EP x dp_shard."; update both occurrences to avoid
terminals/logging display issues.
- Around line 118-126: The matching in the loop using full, short_name,
ignore_list, all_names, and resolved is too permissive because
full.endswith(short_name) can match unintended suffixes; tighten it so you only
accept exact matches or suffixes preceded by a dot (or the PEFT prefix case
already handled), e.g., replace the loose full.endswith(short_name) check with a
check that requires a dot separator before short_name (or use a regex like
(^|\\.)short_name$) so only true path-segment suffixes are resolved into
resolved.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 08a0283e-778d-424f-8b47-587a914046a1

📥 Commits

Reviewing files that changed from the base of the PR and between ca82a1a and f9a23e5.

📒 Files selected for processing (5)
  • src/axolotl/integrations/expert_parallel/__init__.py
  • src/axolotl/integrations/expert_parallel/args.py
  • src/axolotl/integrations/expert_parallel/plugin.py
  • src/axolotl/integrations/expert_parallel/shard.py
  • src/axolotl/utils/collators/mm_chat.py
✅ Files skipped from review due to trivial changes (1)
  • src/axolotl/integrations/expert_parallel/init.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • src/axolotl/integrations/expert_parallel/args.py

@winglian winglian merged commit cc25d3e into main May 21, 2026
20 checks passed
@winglian winglian deleted the feat/ep branch May 21, 2026 13:47
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants