fix use of flex attn kwargs and add support for tiledmlp for GPT-OSS#3116
Conversation
|
Important Review skippedAuto incremental reviews are disabled on this repository. Please check the settings in the CodeRabbit UI or the You can disable this status message by setting the 📝 WalkthroughWalkthroughIntroduces version-aware flex attention compilation and removes mask patching. Adds DeepSpeed-compatible tiled MoE autograd path with tuple-aware shard handling and gradient accumulation. Updates tiled MLP patching to select MoE-specific implementation for gpt_oss models. Adjusts patch manager to pass compile kwargs to flex wrapper only. Changes
Estimated code review effort🎯 4 (Complex) | ⏱️ ~60 minutes Possibly related PRs
Suggested reviewers
✨ Finishing Touches🧪 Generate unit tests
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. 🪧 TipsChatThere are 3 ways to chat with CodeRabbit:
SupportNeed help? Create a ticket on our support page for assistance with any issues or questions. CodeRabbit Commands (Invoked using PR/Issue comments)Type Other keywords and placeholders
Status, Documentation and Community
|
There was a problem hiding this comment.
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (4)
src/axolotl/monkeypatch/attention/flex_attn.py (2)
17-21: Unreachable version branches due to early return; clarify intent or remove the guard.Returning early when torch is not 2.6 makes the 2.5.1 and fallback branches in WrappedFlexAttention.init unreachable. Either:
- Keep the wrapper active for all versions and select behavior inside, or
- Limit the wrapper to 2.6.x only and delete the <=2.5.1 path.
I recommend the first for consistency. Patch:
- is_torch_2_6 = torch.__version__.startswith("2.6") - - if not is_torch_2_6: - return + # Apply wrapper across versions; per-version behavior is handled below.Optionally keep a narrow gate if you explicitly don't want to affect non-2.6: then also delete the
is_torch_less_or_equal("2.5.1")branch to avoid dead code.Also applies to: 51-68
203-210: Duplicate assignment; set once.
make_flex_block_causal_maskis assigned twice to the same target.Apply:
- sys.modules[ - n - ].make_flex_block_causal_mask = patched_make_flex_block_causal_mask - sys.modules[ - n - ].make_flex_block_causal_mask = patched_make_flex_block_causal_mask + sys.modules[n].make_flex_block_causal_mask = patched_make_flex_block_causal_masksrc/axolotl/loaders/patch_manager.py (1)
50-55: Flex attention patching is disabled.
_apply_flex_attention_patches()is commented out, so none of the new flex-attn kwargs/version logic runs. This contradicts the PR goal.Apply:
- # self._apply_flex_attention_patches() + self._apply_flex_attention_patches()src/axolotl/monkeypatch/tiled_mlp/base.py (1)
191-257: GradientAccumulator installs hooks per shard; install once and gate final assignment.Refactor to avoid N× accumulation:
class GradientAccumulator: @@ - def install_hooks(self, is_last_shard: bool): - """Install gradient hooks that accumulate gradients in higher precision""" + def install_hooks(self): + """Install gradient hooks once; use self.is_last_shard to gate final assignment""" + self.is_last_shard = False @@ - def create_hook(param): + def create_hook(param): def hook(grad): with self.lock: grad_to_accum_dtype = grad.to(self.grad_accumulation_dtype) scaled_grad = grad_to_accum_dtype * self.gradient_scale @@ - # Only assign the averaged gradient on the last shard - if is_last_shard: + # Only assign the averaged gradient on the last shard + if self.is_last_shard: param.grad = self.accumulated_grads[param].to(param.dtype) return param.grad return None @@ - for param in self.params: + for param in self.params: if param.requires_grad: hook = param.register_hook(create_hook(param)) self.hooks.append(hook) + + def set_is_last_shard(self, is_last: bool): + self.is_last_shard = is_last
🧹 Nitpick comments (6)
src/axolotl/monkeypatch/attention/flex_attn.py (1)
49-55: Singleton training-state toggle is correct, but document recompilation behavior.When
trainingflips, you recompile with a different mode. Add a one-line comment so future changes don't break this contract.src/axolotl/loaders/patch_manager.py (1)
435-441: Typo in comment.“seperately” → “separately”.
Apply:
- # TODO(MengqingCao): split these patches seperately + # TODO(MengqingCao): split these patches separatelysrc/axolotl/monkeypatch/tiled_mlp/base.py (1)
31-41: Forward tuple concatenation assumes dims [1, 0]; prefer recording actual concat dims.Hardcoding
[1, 0]risks silent shape bugs. The revised TiledMLP forward above records per-element concat dims inctx.tuple_cat_dims. Mirror the same in MoE forward for consistency.src/axolotl/monkeypatch/tiled_mlp/patch.py (3)
20-21: Model-specific DS path is fine; add a guard for non-DS runs of gpt_oss.
DeepSpeedTiledMLPMoEis selected for"gpt_oss"only under DS heuristics. If users set"gpt_oss"without DS, we fall back to localTiledMLP, which likely lacks MoE semantics. Consider a warning so users know MoE tiling needs DS.Apply:
) or os.environ.get("ACCELERATE_USE_DEEPSPEED", "false") == "true": if model_type == "gpt_oss": self._tiled_mlp_dist_impl = DeepSpeedTiledMLPMoE else: self._tiled_mlp_dist_impl = DeepSpeedTiledMLP else: + if model_type == "gpt_oss": + LOG.warning("gpt_oss MoE selected without DeepSpeed; using local TiledMLP fallback", main_process_only=True) self._tiled_mlp_dist_impl = TiledMLPAlso applies to: 66-71
38-39: torch.compile mode parity with flex-attn patch.For 2.6 training,
generic_mlp_forwardmay also benefit frommode="max-autotune-no-cudagraphs"anddynamic=Falseto match the compile guidance you applied for flex attention.
46-55: Shard heuristic can produce 0 or very large shard counts; clamp and validate.When
seqlen < hidden,ceil(seqlen/hidden)→ 1, OK; but add a lower bound (>=1) and an upper bound (<=seqlen). Also ensure divisibility assumptions are removed (your base fixes above handle uneven chunks).Apply:
if cfg_num_shards is None: - num_shards = math.ceil(seqlen / hidden) + num_shards = max(1, min(seqlen, math.ceil(seqlen / max(1, hidden))))
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
📒 Files selected for processing (4)
src/axolotl/loaders/patch_manager.py(1 hunks)src/axolotl/monkeypatch/attention/flex_attn.py(2 hunks)src/axolotl/monkeypatch/tiled_mlp/base.py(4 hunks)src/axolotl/monkeypatch/tiled_mlp/patch.py(2 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
src/axolotl/loaders/patch_manager.py (1)
src/axolotl/monkeypatch/attention/flex_attn.py (1)
patch_flex_wrapper(15-77)
src/axolotl/monkeypatch/tiled_mlp/patch.py (1)
src/axolotl/monkeypatch/tiled_mlp/base.py (2)
DeepSpeedTiledMLPMoE(11-96)TiledMLP(99-188)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (6)
- GitHub Check: PyTest from Source Dist (3.11, 2.8.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.6.0)
- GitHub Check: PyTest (3.11, 2.7.1)
- GitHub Check: PyTest (3.11, 2.8.0)
- GitHub Check: PyTest from Source Dist (3.11, 2.7.1)
- GitHub Check: PyTest (3.11, 2.6.0)
🔇 Additional comments (2)
src/axolotl/loaders/patch_manager.py (1)
150-156: Good: flex-attn kwargs now flow only to the wrapper.Import and pass-through look correct; no stray kwargs to unrelated patches.
src/axolotl/monkeypatch/tiled_mlp/base.py (1)
24-31: Minor: skip saving non-grad params.
ctx.compute_params = [p for p in compute_params if p.requires_grad]is good; consider guarding whencompute_paramscan beNone.Would
compute_paramsever beNonefor MoE path? If yes, addor []to avoidTypeError.
| x_shards = list(torch.chunk(x, chunks=shards, dim=1)) | ||
| with torch.no_grad(): | ||
| output_shards = [fn(self, x_shard) for x_shard in x_shards] | ||
|
|
||
| ctx.is_tuple_output = isinstance(output_shards[0], tuple) | ||
| if isinstance(output_shards[0], tuple): | ||
| tuple_dim_idx = [1, 0] | ||
| output_unsharded = tuple( | ||
| torch.cat( | ||
| [output_shard[i] for output_shard in output_shards], | ||
| dim=tuple_dim_idx[i], | ||
| ) | ||
| for i in range(len(output_shards[0])) | ||
| ) | ||
| else: | ||
| output_unsharded = torch.cat(output_shards, dim=1) | ||
|
|
||
| return output_unsharded | ||
|
|
There was a problem hiding this comment.
🛠️ Refactor suggestion
Backward ignores tuple outputs and slices grads incorrectly; fix tuple-aware sharding and uneven chunk sizes.
Issues:
- Only
grads[0]is propagated when output is a tuple; others are dropped. shard_step = x_shards[0].numel()assumes equal chunks;torch.chunkcan produce uneven sizes.- Slicing grads via flatten/narrow can mismatch per-dim sharding; shard along sequence dim instead.
Patch (DeepSpeedTiledMLPMoE.backward):
- incoming_grad = grads[0]
- x_grad = torch.zeros_like(x)
- x_shards = list(torch.chunk(x, chunks=shards, dim=1))
-
- shard_step = x_shards[0].numel()
- for i, x_shard in enumerate(x_shards):
+ # grads can be a tuple matching forward outputs
+ incoming_grads = grads if is_tuple_output else (grads[0],)
+ x_grad = torch.zeros_like(x)
+ x_shards = list(torch.chunk(x, chunks=shards, dim=1))
+ # shard along sequence dim (dim=1); handle uneven chunks
+ seq_offsets = []
+ _off = 0
+ for xs in x_shards:
+ seq_offsets.append((_off, xs.size(1)))
+ _off += xs.size(1)
+
+ for i, x_shard in enumerate(x_shards):
# Tell deepspeed not to add a new grad to its ipg bucket until the last shard is run
if compute_params is not None:
if i + 1 < shards:
for param in compute_params:
param.ds_grad_is_ready = False
else:
# last shard, can add the grad
for param in compute_params:
param.ds_grad_is_ready = True
x_shard.requires_grad_(x_requires_grad)
- shard_offset = i * shard_step
- x_shard.grad = (
- x_grad.view(-1)
- .narrow(0, shard_offset, x_shard.numel())
- .view_as(x_shard)
- )
- incoming_grad_shard = (
- incoming_grad.view(-1)
- .narrow(0, shard_offset, x_shard.numel())
- .view_as(x_shard)
- )
+ seq_off, seq_len = seq_offsets[i]
+ # assign grad view for this shard along sequence dim
+ x_shard.grad = x_grad.narrow(1, seq_off, seq_len)
+
with torch.enable_grad():
output = fn(self, x_shard)
- if is_tuple_output:
- torch.autograd.backward(output[0], incoming_grad_shard)
- else:
- torch.autograd.backward(output, incoming_grad_shard)
+ if is_tuple_output:
+ # assume first tuple element is sharded along seq dim (=1);
+ # for additional elements, also slice along seq dim unless model dictates otherwise.
+ grad_inputs = []
+ for g in incoming_grads:
+ grad_inputs.append(g.narrow(1, seq_off, seq_len))
+ torch.autograd.backward(output, grad_inputs)
+ else:
+ incoming_grad_shard = incoming_grads[0].narrow(1, seq_off, seq_len)
+ torch.autograd.backward(output, incoming_grad_shard)If some tuple elements concatenate along a different dim, store tuple_cat_dims in ctx during forward and slice per-element accordingly (see TiledMLP patch below).
Also applies to: 60-96
🤖 Prompt for AI Agents
In src/axolotl/monkeypatch/tiled_mlp/base.py around lines 27 to 45 (and
similarly for lines 60–96), the backward currently only propagates grads[0] for
tuple outputs, assumes equal chunk sizes, and slices grads using an incorrect
flatten/narrow approach; update forward to save tuple_cat_dims and the per-shard
lengths (e.g., list of x_shards[i].size(sequence_dim) or per-chunk numel
per-dim) into ctx, and in backward: iterate over all tuple elements and
reconstruct per-shard grad tuples (not just grads[0]), compute precise slice
offsets from the actual chunk sizes (handle uneven torch.chunk results) and
slice each tuple element along its saved concatenation dim (tuple_cat_dims)
instead of flatten/narrow, then call the original fn.backward (or invoke
autograd correctly) with the per-shard tuple of grads so every tuple component
and uneven chunk is handled correctly.
| x_shards = list(torch.chunk(x, chunks=shards, dim=1)) | ||
| with torch.no_grad(): | ||
| output_shards = [fn(self, x_shard) for x_shard in x_shards] | ||
| output_unsharded = torch.cat(output_shards, dim=1) | ||
| ctx.is_tuple_output = isinstance(output_shards[0], tuple) | ||
| if isinstance(output_shards[0], tuple): | ||
| tuple_dim_idx = [1, 0] | ||
| output_unsharded = tuple( | ||
| torch.cat( | ||
| [output_shard[i] for output_shard in output_shards], | ||
| dim=tuple_dim_idx[i], | ||
| ) | ||
| for i in range(len(output_shards[0])) | ||
| ) | ||
| else: | ||
| output_unsharded = torch.cat(output_shards, dim=1) | ||
|
|
||
| return output_unsharded | ||
|
|
There was a problem hiding this comment.
🛠️ Refactor suggestion
Same tuple/backward issues in TiledMLP; plus gradient hook misuse causes N× accumulation.
Issues:
- Same tuple grad loss and uneven chunking as above.
GradientAccumulator.install_hooks()is called per shard, registering multiple hooks on the same param. Each grad triggers all hooks → over-accumulation.
Fix tuple grads and install hooks once:
- with torch.no_grad():
- output_shards = [fn(self, x_shard) for x_shard in x_shards]
- ctx.is_tuple_output = isinstance(output_shards[0], tuple)
- if isinstance(output_shards[0], tuple):
- tuple_dim_idx = [1, 0]
+ with torch.no_grad():
+ output_shards = [fn(self, x_shard) for x_shard in x_shards]
+ ctx.is_tuple_output = isinstance(output_shards[0], tuple)
+ if ctx.is_tuple_output:
+ # record concat dims per output element; default to seq dim (1)
+ tuple_cat_dims = [1 for _ in range(len(output_shards[0]))]
+ ctx.tuple_cat_dims = tuple_cat_dims
output_unsharded = tuple(
torch.cat(
- [output_shard[i] for output_shard in output_shards],
- dim=tuple_dim_idx[i],
+ [output_shard[i] for output_shard in output_shards],
+ dim=tuple_cat_dims[i],
)
for i in range(len(output_shards[0]))
)
else:
output_unsharded = torch.cat(output_shards, dim=1)- # Create a gradient accumulator for parameters
- grad_accumulator = GradientAccumulator(compute_params, shards, dtype=x.dtype)
+ # Create a gradient accumulator; install hooks once
+ grad_accumulator = GradientAccumulator(compute_params, shards, dtype=x.dtype)
+ grad_accumulator.install_hooks()
- shard_step = x_shards[0].numel()
- for i, x_shard in enumerate(x_shards):
+ # slice along seq dim; handle uneven chunks
+ seq_offsets = []
+ _off = 0
+ for xs in x_shards:
+ seq_offsets.append((_off, xs.size(1)))
+ _off += xs.size(1)
+ for i, x_shard in enumerate(x_shards):
x_shard.requires_grad_(x_requires_grad)
- shard_offset = i * shard_step
- x_shard.grad = (
- x_grad.view(-1)
- .narrow(0, shard_offset, x_shard.numel())
- .view_as(x_shard)
- )
- incoming_grad_shard = (
- incoming_grad.view(-1)
- .narrow(0, shard_offset, x_shard.numel())
- .view_as(x_shard)
- )
+ seq_off, seq_len = seq_offsets[i]
+ x_shard.grad = x_grad.narrow(1, seq_off, seq_len)
- # Install hooks for this shard
- is_last_shard = i + 1 == shards
- grad_accumulator.install_hooks(is_last_shard)
+ # Flip hook behavior only for the last shard
+ grad_accumulator.set_is_last_shard(i + 1 == shards)
with torch.enable_grad():
output = fn(self, x_shard)
- if is_tuple_output:
- torch.autograd.backward(output[0], incoming_grad_shard)
- else:
- torch.autograd.backward(output, incoming_grad_shard)
+ if is_tuple_output:
+ grad_inputs = []
+ for idx, g in enumerate(grads):
+ dim = getattr(ctx, "tuple_cat_dims", [1]*len(grads))[idx]
+ if g is None:
+ grad_inputs.append(None)
+ elif dim == 1:
+ grad_inputs.append(g.narrow(1, seq_off, seq_len))
+ elif dim == 0:
+ bsz = x_shard.size(0)
+ grad_inputs.append(g.narrow(0, i * bsz, bsz))
+ else:
+ raise RuntimeError(f"Unsupported concat dim {dim} for tuple output")
+ torch.autograd.backward(output, grad_inputs)
+ else:
+ incoming_grad_shard = incoming_grad.narrow(1, seq_off, seq_len)
+ torch.autograd.backward(output, incoming_grad_shard)Also applies to: 137-189
Codecov Report❌ Patch coverage is 📢 Thoughts on this report? Let us know! |
| # cause errors. The suggested fix is to compile with "max-autotune-no-cudagraphs" | ||
| # see https://github.com/pytorch/pytorch/issues/146260 for training | ||
| self.training = training | ||
| LOG.info( |
There was a problem hiding this comment.
I found this helpful as sometimes compilation took a while and it was good to know training wasn't stuck
f118a49 to
56a76a7
Compare
Description
Motivation and Context
How has this been tested?
Screenshots (if appropriate)
Types of changes
Social Handles (Optional)
Summary by CodeRabbit