Skip to content

fix use of flex attn kwargs and add support for tiledmlp for GPT-OSS#3116

Merged
winglian merged 3 commits into
mainfrom
flex-attn-kwargs
Aug 29, 2025
Merged

fix use of flex attn kwargs and add support for tiledmlp for GPT-OSS#3116
winglian merged 3 commits into
mainfrom
flex-attn-kwargs

Conversation

@winglian

@winglian winglian commented Aug 29, 2025

Copy link
Copy Markdown
Collaborator

Description

Motivation and Context

How has this been tested?

Screenshots (if appropriate)

Types of changes

Social Handles (Optional)

Summary by CodeRabbit

  • New Features
    • Added Mixture-of-Experts support to tiled MLP for gpt_oss models, enabling distributed MoE training/inference.
  • Bug Fixes
    • Improved FlexAttention stability and compatibility across PyTorch 2.5.1 and 2.6.0 via version-aware compilation to reduce compile/runtime issues.
    • Fixed gradient handling for tiled MLP with sharded execution, ensuring correct accumulation across shards and proper support for tuple outputs.

@coderabbitai

coderabbitai Bot commented Aug 29, 2025

Copy link
Copy Markdown
Contributor

Important

Review skipped

Auto incremental reviews are disabled on this repository.

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

You can disable this status message by setting the reviews.review_status to false in the CodeRabbit configuration file.

📝 Walkthrough

Walkthrough

Introduces version-aware flex attention compilation and removes mask patching. Adds DeepSpeed-compatible tiled MoE autograd path with tuple-aware shard handling and gradient accumulation. Updates tiled MLP patching to select MoE-specific implementation for gpt_oss models. Adjusts patch manager to pass compile kwargs to flex wrapper only.

Changes

Cohort / File(s) Summary of changes
Patch manager: flex attention invocation
src/axolotl/loaders/patch_manager.py
Stops importing/using mask patch; builds flex_attn_compile_kwargs from config and calls patch_flex_wrapper only; retains sample_packing causal mask patch.
Flex attention: version-aware compile
src/axolotl/monkeypatch/attention/flex_attn.py
Adds torch/transformers version checks; compiles with dynamic/mode variants based on version and training; removes previous unconditional compile/logging; sets training state and compiled flag.
Tiled MLP core (MoE + grads)
src/axolotl/monkeypatch/tiled_mlp/base.py
Adds DeepSpeedTiledMLPMoE autograd.Function and GradientAccumulator; makes forward/backward tuple-aware; implements shard-wise grad accumulation and parameter hook management.
Tiled MLP patching (impl selection)
src/axolotl/monkeypatch/tiled_mlp/patch.py
Chooses DeepSpeedTiledMLPMoE for model_type "gpt_oss", otherwise DeepSpeedTiledMLP; imports new classes; non-distributed path unchanged.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related PRs

  • TiledMLP support #2865 — Touches the same tiled MLP and patch manager areas, introducing DeepSpeedTiledMLPMoE, tuple-aware handling, and gradient accumulation.

Suggested reviewers

  • djsaunde
  • SalmanMohammadi
✨ Finishing Touches
🧪 Generate unit tests
  • Create PR with unit tests
  • Post copyable unit tests in a comment
  • Commit unit tests in branch flex-attn-kwargs

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share
🪧 Tips

Chat

There are 3 ways to chat with CodeRabbit:

  • Review comments: Directly reply to a review comment made by CodeRabbit. Example:
    • I pushed a fix in commit <commit_id>, please review it.
    • Open a follow-up GitHub issue for this discussion.
  • Files and specific lines of code (under the "Files changed" tab): Tag @coderabbitai in a new review comment at the desired location with your query.
  • PR comments: Tag @coderabbitai in a new PR comment to ask questions about the PR branch. For the best results, please provide a very specific query, as very limited context is provided in this mode. Examples:
    • @coderabbitai gather interesting stats about this repository and render them as a table. Additionally, render a pie chart showing the language distribution in the codebase.
    • @coderabbitai read the files in the src/scheduler package and generate a class diagram using mermaid and a README in the markdown format.

Support

Need help? Create a ticket on our support page for assistance with any issues or questions.

CodeRabbit Commands (Invoked using PR/Issue comments)

Type @coderabbitai help to get the list of available commands.

Other keywords and placeholders

  • Add @coderabbitai ignore or @coderabbit ignore anywhere in the PR description to prevent this PR from being reviewed.
  • Add @coderabbitai summary to generate the high-level summary at a specific location in the PR description.
  • Add @coderabbitai anywhere in the PR title to generate the title automatically.

Status, Documentation and Community

  • Visit our Status Page to check the current availability of CodeRabbit.
  • Visit our Documentation for detailed information on how to use CodeRabbit.
  • Join our Discord Community to get help, request features, and share feedback.
  • Follow us on X/Twitter for updates and announcements.

@coderabbitai coderabbitai Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (4)
src/axolotl/monkeypatch/attention/flex_attn.py (2)

17-21: Unreachable version branches due to early return; clarify intent or remove the guard.

Returning early when torch is not 2.6 makes the 2.5.1 and fallback branches in WrappedFlexAttention.init unreachable. Either:

  • Keep the wrapper active for all versions and select behavior inside, or
  • Limit the wrapper to 2.6.x only and delete the <=2.5.1 path.

I recommend the first for consistency. Patch:

-    is_torch_2_6 = torch.__version__.startswith("2.6")
-
-    if not is_torch_2_6:
-        return
+    # Apply wrapper across versions; per-version behavior is handled below.

Optionally keep a narrow gate if you explicitly don't want to affect non-2.6: then also delete the is_torch_less_or_equal("2.5.1") branch to avoid dead code.

Also applies to: 51-68


203-210: Duplicate assignment; set once.

make_flex_block_causal_mask is assigned twice to the same target.

Apply:

-                sys.modules[
-                    n
-                ].make_flex_block_causal_mask = patched_make_flex_block_causal_mask
-                sys.modules[
-                    n
-                ].make_flex_block_causal_mask = patched_make_flex_block_causal_mask
+                sys.modules[n].make_flex_block_causal_mask = patched_make_flex_block_causal_mask
src/axolotl/loaders/patch_manager.py (1)

50-55: Flex attention patching is disabled.

_apply_flex_attention_patches() is commented out, so none of the new flex-attn kwargs/version logic runs. This contradicts the PR goal.

Apply:

-        # self._apply_flex_attention_patches()
+        self._apply_flex_attention_patches()
src/axolotl/monkeypatch/tiled_mlp/base.py (1)

191-257: GradientAccumulator installs hooks per shard; install once and gate final assignment.

Refactor to avoid N× accumulation:

 class GradientAccumulator:
@@
-    def install_hooks(self, is_last_shard: bool):
-        """Install gradient hooks that accumulate gradients in higher precision"""
+    def install_hooks(self):
+        """Install gradient hooks once; use self.is_last_shard to gate final assignment"""
+        self.is_last_shard = False
@@
-        def create_hook(param):
+        def create_hook(param):
             def hook(grad):
                 with self.lock:
                     grad_to_accum_dtype = grad.to(self.grad_accumulation_dtype)
                     scaled_grad = grad_to_accum_dtype * self.gradient_scale
@@
-                    # Only assign the averaged gradient on the last shard
-                    if is_last_shard:
+                    # Only assign the averaged gradient on the last shard
+                    if self.is_last_shard:
                         param.grad = self.accumulated_grads[param].to(param.dtype)
                         return param.grad
                     return None
@@
-        for param in self.params:
+        for param in self.params:
             if param.requires_grad:
                 hook = param.register_hook(create_hook(param))
                 self.hooks.append(hook)
+
+    def set_is_last_shard(self, is_last: bool):
+        self.is_last_shard = is_last
🧹 Nitpick comments (6)
src/axolotl/monkeypatch/attention/flex_attn.py (1)

49-55: Singleton training-state toggle is correct, but document recompilation behavior.

When training flips, you recompile with a different mode. Add a one-line comment so future changes don't break this contract.

src/axolotl/loaders/patch_manager.py (1)

435-441: Typo in comment.

“seperately” → “separately”.

Apply:

-            # TODO(MengqingCao): split these patches seperately
+            # TODO(MengqingCao): split these patches separately
src/axolotl/monkeypatch/tiled_mlp/base.py (1)

31-41: Forward tuple concatenation assumes dims [1, 0]; prefer recording actual concat dims.

Hardcoding [1, 0] risks silent shape bugs. The revised TiledMLP forward above records per-element concat dims in ctx.tuple_cat_dims. Mirror the same in MoE forward for consistency.

src/axolotl/monkeypatch/tiled_mlp/patch.py (3)

20-21: Model-specific DS path is fine; add a guard for non-DS runs of gpt_oss.

DeepSpeedTiledMLPMoE is selected for "gpt_oss" only under DS heuristics. If users set "gpt_oss" without DS, we fall back to local TiledMLP, which likely lacks MoE semantics. Consider a warning so users know MoE tiling needs DS.

Apply:

                 ) or os.environ.get("ACCELERATE_USE_DEEPSPEED", "false") == "true":
                     if model_type == "gpt_oss":
                         self._tiled_mlp_dist_impl = DeepSpeedTiledMLPMoE
                     else:
                         self._tiled_mlp_dist_impl = DeepSpeedTiledMLP
                 else:
+                    if model_type == "gpt_oss":
+                        LOG.warning("gpt_oss MoE selected without DeepSpeed; using local TiledMLP fallback", main_process_only=True)
                     self._tiled_mlp_dist_impl = TiledMLP

Also applies to: 66-71


38-39: torch.compile mode parity with flex-attn patch.

For 2.6 training, generic_mlp_forward may also benefit from mode="max-autotune-no-cudagraphs" and dynamic=False to match the compile guidance you applied for flex attention.


46-55: Shard heuristic can produce 0 or very large shard counts; clamp and validate.

When seqlen < hidden, ceil(seqlen/hidden) → 1, OK; but add a lower bound (>=1) and an upper bound (<=seqlen). Also ensure divisibility assumptions are removed (your base fixes above handle uneven chunks).

Apply:

             if cfg_num_shards is None:
-                num_shards = math.ceil(seqlen / hidden)
+                num_shards = max(1, min(seqlen, math.ceil(seqlen / max(1, hidden))))
📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between 6afba38 and 13586f8.

📒 Files selected for processing (4)
  • src/axolotl/loaders/patch_manager.py (1 hunks)
  • src/axolotl/monkeypatch/attention/flex_attn.py (2 hunks)
  • src/axolotl/monkeypatch/tiled_mlp/base.py (4 hunks)
  • src/axolotl/monkeypatch/tiled_mlp/patch.py (2 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
src/axolotl/loaders/patch_manager.py (1)
src/axolotl/monkeypatch/attention/flex_attn.py (1)
  • patch_flex_wrapper (15-77)
src/axolotl/monkeypatch/tiled_mlp/patch.py (1)
src/axolotl/monkeypatch/tiled_mlp/base.py (2)
  • DeepSpeedTiledMLPMoE (11-96)
  • TiledMLP (99-188)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (6)
  • GitHub Check: PyTest from Source Dist (3.11, 2.8.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
  • GitHub Check: PyTest (3.11, 2.7.1)
  • GitHub Check: PyTest (3.11, 2.8.0)
  • GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
  • GitHub Check: PyTest (3.11, 2.6.0)
🔇 Additional comments (2)
src/axolotl/loaders/patch_manager.py (1)

150-156: Good: flex-attn kwargs now flow only to the wrapper.

Import and pass-through look correct; no stray kwargs to unrelated patches.

src/axolotl/monkeypatch/tiled_mlp/base.py (1)

24-31: Minor: skip saving non-grad params.

ctx.compute_params = [p for p in compute_params if p.requires_grad] is good; consider guarding when compute_params can be None.

Would compute_params ever be None for MoE path? If yes, add or [] to avoid TypeError.

Comment on lines +27 to +45
x_shards = list(torch.chunk(x, chunks=shards, dim=1))
with torch.no_grad():
output_shards = [fn(self, x_shard) for x_shard in x_shards]

ctx.is_tuple_output = isinstance(output_shards[0], tuple)
if isinstance(output_shards[0], tuple):
tuple_dim_idx = [1, 0]
output_unsharded = tuple(
torch.cat(
[output_shard[i] for output_shard in output_shards],
dim=tuple_dim_idx[i],
)
for i in range(len(output_shards[0]))
)
else:
output_unsharded = torch.cat(output_shards, dim=1)

return output_unsharded

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

⚠️ Potential issue

Backward ignores tuple outputs and slices grads incorrectly; fix tuple-aware sharding and uneven chunk sizes.

Issues:

  • Only grads[0] is propagated when output is a tuple; others are dropped.
  • shard_step = x_shards[0].numel() assumes equal chunks; torch.chunk can produce uneven sizes.
  • Slicing grads via flatten/narrow can mismatch per-dim sharding; shard along sequence dim instead.

Patch (DeepSpeedTiledMLPMoE.backward):

-        incoming_grad = grads[0]
-        x_grad = torch.zeros_like(x)
-        x_shards = list(torch.chunk(x, chunks=shards, dim=1))
-
-        shard_step = x_shards[0].numel()
-        for i, x_shard in enumerate(x_shards):
+        # grads can be a tuple matching forward outputs
+        incoming_grads = grads if is_tuple_output else (grads[0],)
+        x_grad = torch.zeros_like(x)
+        x_shards = list(torch.chunk(x, chunks=shards, dim=1))
+        # shard along sequence dim (dim=1); handle uneven chunks
+        seq_offsets = []
+        _off = 0
+        for xs in x_shards:
+            seq_offsets.append((_off, xs.size(1)))
+            _off += xs.size(1)
+
+        for i, x_shard in enumerate(x_shards):
             # Tell deepspeed not to add a new grad to its ipg bucket until the last shard is run
             if compute_params is not None:
                 if i + 1 < shards:
                     for param in compute_params:
                         param.ds_grad_is_ready = False
                 else:
                     # last shard, can add the grad
                     for param in compute_params:
                         param.ds_grad_is_ready = True
 
             x_shard.requires_grad_(x_requires_grad)
 
-            shard_offset = i * shard_step
-            x_shard.grad = (
-                x_grad.view(-1)
-                .narrow(0, shard_offset, x_shard.numel())
-                .view_as(x_shard)
-            )
-            incoming_grad_shard = (
-                incoming_grad.view(-1)
-                .narrow(0, shard_offset, x_shard.numel())
-                .view_as(x_shard)
-            )
+            seq_off, seq_len = seq_offsets[i]
+            # assign grad view for this shard along sequence dim
+            x_shard.grad = x_grad.narrow(1, seq_off, seq_len)
+
             with torch.enable_grad():
                 output = fn(self, x_shard)
-            if is_tuple_output:
-                torch.autograd.backward(output[0], incoming_grad_shard)
-            else:
-                torch.autograd.backward(output, incoming_grad_shard)
+            if is_tuple_output:
+                # assume first tuple element is sharded along seq dim (=1);
+                # for additional elements, also slice along seq dim unless model dictates otherwise.
+                grad_inputs = []
+                for g in incoming_grads:
+                    grad_inputs.append(g.narrow(1, seq_off, seq_len))
+                torch.autograd.backward(output, grad_inputs)
+            else:
+                incoming_grad_shard = incoming_grads[0].narrow(1, seq_off, seq_len)
+                torch.autograd.backward(output, incoming_grad_shard)

If some tuple elements concatenate along a different dim, store tuple_cat_dims in ctx during forward and slice per-element accordingly (see TiledMLP patch below).

Also applies to: 60-96

🤖 Prompt for AI Agents
In src/axolotl/monkeypatch/tiled_mlp/base.py around lines 27 to 45 (and
similarly for lines 60–96), the backward currently only propagates grads[0] for
tuple outputs, assumes equal chunk sizes, and slices grads using an incorrect
flatten/narrow approach; update forward to save tuple_cat_dims and the per-shard
lengths (e.g., list of x_shards[i].size(sequence_dim) or per-chunk numel
per-dim) into ctx, and in backward: iterate over all tuple elements and
reconstruct per-shard grad tuples (not just grads[0]), compute precise slice
offsets from the actual chunk sizes (handle uneven torch.chunk results) and
slice each tuple element along its saved concatenation dim (tuple_cat_dims)
instead of flatten/narrow, then call the original fn.backward (or invoke
autograd correctly) with the per-shard tuple of grads so every tuple component
and uneven chunk is handled correctly.

Comment on lines 119 to 136
x_shards = list(torch.chunk(x, chunks=shards, dim=1))
with torch.no_grad():
output_shards = [fn(self, x_shard) for x_shard in x_shards]
output_unsharded = torch.cat(output_shards, dim=1)
ctx.is_tuple_output = isinstance(output_shards[0], tuple)
if isinstance(output_shards[0], tuple):
tuple_dim_idx = [1, 0]
output_unsharded = tuple(
torch.cat(
[output_shard[i] for output_shard in output_shards],
dim=tuple_dim_idx[i],
)
for i in range(len(output_shards[0]))
)
else:
output_unsharded = torch.cat(output_shards, dim=1)

return output_unsharded

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

⚠️ Potential issue

Same tuple/backward issues in TiledMLP; plus gradient hook misuse causes N× accumulation.

Issues:

  • Same tuple grad loss and uneven chunking as above.
  • GradientAccumulator.install_hooks() is called per shard, registering multiple hooks on the same param. Each grad triggers all hooks → over-accumulation.

Fix tuple grads and install hooks once:

-        with torch.no_grad():
-            output_shards = [fn(self, x_shard) for x_shard in x_shards]
-        ctx.is_tuple_output = isinstance(output_shards[0], tuple)
-        if isinstance(output_shards[0], tuple):
-            tuple_dim_idx = [1, 0]
+        with torch.no_grad():
+            output_shards = [fn(self, x_shard) for x_shard in x_shards]
+        ctx.is_tuple_output = isinstance(output_shards[0], tuple)
+        if ctx.is_tuple_output:
+            # record concat dims per output element; default to seq dim (1)
+            tuple_cat_dims = [1 for _ in range(len(output_shards[0]))]
+            ctx.tuple_cat_dims = tuple_cat_dims
             output_unsharded = tuple(
                 torch.cat(
-                    [output_shard[i] for output_shard in output_shards],
-                    dim=tuple_dim_idx[i],
+                    [output_shard[i] for output_shard in output_shards],
+                    dim=tuple_cat_dims[i],
                 )
                 for i in range(len(output_shards[0]))
             )
         else:
             output_unsharded = torch.cat(output_shards, dim=1)
-        # Create a gradient accumulator for parameters
-        grad_accumulator = GradientAccumulator(compute_params, shards, dtype=x.dtype)
+        # Create a gradient accumulator; install hooks once
+        grad_accumulator = GradientAccumulator(compute_params, shards, dtype=x.dtype)
+        grad_accumulator.install_hooks()
 
-        shard_step = x_shards[0].numel()
-        for i, x_shard in enumerate(x_shards):
+        # slice along seq dim; handle uneven chunks
+        seq_offsets = []
+        _off = 0
+        for xs in x_shards:
+            seq_offsets.append((_off, xs.size(1)))
+            _off += xs.size(1)
+        for i, x_shard in enumerate(x_shards):
             x_shard.requires_grad_(x_requires_grad)
 
-            shard_offset = i * shard_step
-            x_shard.grad = (
-                x_grad.view(-1)
-                .narrow(0, shard_offset, x_shard.numel())
-                .view_as(x_shard)
-            )
-            incoming_grad_shard = (
-                incoming_grad.view(-1)
-                .narrow(0, shard_offset, x_shard.numel())
-                .view_as(x_shard)
-            )
+            seq_off, seq_len = seq_offsets[i]
+            x_shard.grad = x_grad.narrow(1, seq_off, seq_len)
 
-            # Install hooks for this shard
-            is_last_shard = i + 1 == shards
-            grad_accumulator.install_hooks(is_last_shard)
+            # Flip hook behavior only for the last shard
+            grad_accumulator.set_is_last_shard(i + 1 == shards)
 
             with torch.enable_grad():
                 output = fn(self, x_shard)
-            if is_tuple_output:
-                torch.autograd.backward(output[0], incoming_grad_shard)
-            else:
-                torch.autograd.backward(output, incoming_grad_shard)
+            if is_tuple_output:
+                grad_inputs = []
+                for idx, g in enumerate(grads):
+                    dim = getattr(ctx, "tuple_cat_dims", [1]*len(grads))[idx]
+                    if g is None:
+                        grad_inputs.append(None)
+                    elif dim == 1:
+                        grad_inputs.append(g.narrow(1, seq_off, seq_len))
+                    elif dim == 0:
+                        bsz = x_shard.size(0)
+                        grad_inputs.append(g.narrow(0, i * bsz, bsz))
+                    else:
+                        raise RuntimeError(f"Unsupported concat dim {dim} for tuple output")
+                torch.autograd.backward(output, grad_inputs)
+            else:
+                incoming_grad_shard = incoming_grad.narrow(1, seq_off, seq_len)
+                torch.autograd.backward(output, incoming_grad_shard)

Also applies to: 137-189

@codecov

codecov Bot commented Aug 29, 2025

Copy link
Copy Markdown

# cause errors. The suggested fix is to compile with "max-autotune-no-cudagraphs"
# see https://github.com/pytorch/pytorch/issues/146260 for training
self.training = training
LOG.info(

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I found this helpful as sometimes compilation took a while and it was good to know training wasn't stuck

@winglian winglian merged commit 0094a2d into main Aug 29, 2025
13 of 16 checks passed
@winglian winglian deleted the flex-attn-kwargs branch August 29, 2025 17:52
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants