fix: qwen3-next to use fla causal-conv1d to support packing #3437

Merged
winglian merged 3 commits into main from fix/qwen3-next-causal on Mar 3, 2026
Conversation

@NanoCode012
Collaborator

@NanoCode012 NanoCode012 commented Feb 26, 2026

Description

Since we ask users to uninstall tridao's causal-conv1d, we've been falling back to PyTorch's op, which doesn't handle cu_seqlens. This PR fixes that and also uses a Triton kernel for slightly better performance.

Thanks to morphism for the report https://discord.com/channels/1104757954588196865/1104757955204743201/1476527723512987730

Context:

Hmm, though I notice that seq_idx is still None for the causal_conv1d part? Is that correct?

Ah, might be a tiny leak then?  Could maybe use fla.modules.convolution.causal_conv1d with cu_seqlens instead?
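
To illustrate the leak discussed above, here is a minimal pure-Python sketch (not the PR's code and not the FLA kernel). The helper names `causal_conv1d` and `causal_conv1d_packed` are hypothetical and exist only for this illustration. It shows that a causal convolution run over a whole packed row mixes the tail of one sequence into the first `kernel_size - 1` tokens of the next, while a boundary-aware version driven by `cu_seqlens` does not.

```python
def causal_conv1d(x, weights):
    """Causal 1D convolution over a list of floats, left-padded with zeros."""
    k = len(weights)
    padded = [0.0] * (k - 1) + list(x)
    # Output at position t only sees inputs t-(k-1) .. t.
    return [
        sum(w * padded[t + j] for j, w in enumerate(weights))
        for t in range(len(x))
    ]


def causal_conv1d_packed(x, weights, cu_seqlens):
    """Apply the causal convolution to each packed segment independently."""
    out = []
    for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:]):
        out.extend(causal_conv1d(x[start:end], weights))
    return out


# Two sequences of lengths 3 and 2 packed into a single row.
packed = [1.0, 2.0, 3.0, 10.0, 20.0]
cu_seqlens = [0, 3, 5]
weights = [0.5, 0.25, 0.25]  # kernel size 3; the last weight hits the current token

naive = causal_conv1d(packed, weights)                      # ignores boundaries
aware = causal_conv1d_packed(packed, weights, cu_seqlens)   # respects boundaries

# Positions where the boundary-unaware result differs from the correct one.
leaked_positions = [t for t, (a, b) in enumerate(zip(naive, aware)) if a != b]
print(leaked_positions)
```

With a kernel size of 3, only the first two tokens of the second sequence are affected, which is why the leak is small but still incorrect for packed training.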

Breaking change

Fails hard on packing + no FLA

Motivation and Context

How has this been tested?

  • To manual run

AI Usage Disclaimer

Screenshots (if appropriate)

Types of changes

Social Handles (Optional)

Summary by CodeRabbit

Release Notes

  • Documentation

    • Simplified installation and setup instructions for clearer onboarding
    • Updated VRAM usage guidance with specific memory ranges for different configurations
  • Configuration Updates

    • Enhanced training parameters and optimization settings
    • Added new configuration flags for improved performance and memory efficiency
  • Performance Improvements

    • Optimized computation handling for better resource utilization

@coderabbitai
Contributor

coderabbitai Bot commented Feb 26, 2026

📝 Walkthrough

Walkthrough

Updated Qwen3-Next example documentation and configuration to simplify installation steps, adjust LoRA training parameters (dropout to 0, expanded target modules), add MoE expert quantization, and enhance modeling with improved causal convolution handling and cu_seqlens computation.

Changes

| Cohort / File(s) | Summary |
|---|---|
| Qwen3-Next Example Configuration: examples/qwen3-next/README.md, examples/qwen3-next/qwen3-next-80b-a3b-qlora.yaml | Simplified installation instructions; removed main/nightlies and Docker references; updated FLA dependency from 0.3.2 to 0.4.1; adjusted VRAM usage notes with variant-specific ranges; added quantize_moe_experts: true flag; changed lora_dropout from 0.05 to 0; expanded lora_target_modules to include attention projections and shared expert parameters; added kernel toggle flags (lora_mlp_kernel, lora_qkv_kernel, lora_o_kernel). |
| Qwen3-Next Modeling Patch: src/axolotl/monkeypatch/models/qwen3_next/modeling.py | Added conditional import and fallback handling for fla_causal_conv1d; optimized cu_seqlens computation for both precomputed and cached inference paths; enhanced tensor shape transformations with proper transposes for causal_conv1d_update; implemented fallback to PyTorch conv1d with warnings; added clarifying comments on shape expectations for inference vs. training paths. |
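
As a rough illustration of what a cu_seqlens computation involves, the sketch below derives cumulative sequence boundaries from packed position ids. It is an assumption-laden example, not the code in modeling.py: the function name `cu_seqlens_from_position_ids` is hypothetical, it uses plain Python lists rather than tensors, and it assumes each packed sequence restarts its position ids at 0.

```python
def cu_seqlens_from_position_ids(position_ids):
    """Return cumulative sequence boundaries for one packed row.

    Assumption: every packed sequence restarts its position ids at 0,
    so each 0 marks the start of a new sequence.
    """
    starts = [i for i, pos in enumerate(position_ids) if pos == 0]
    # The final boundary is the total number of tokens in the row.
    return starts + [len(position_ids)]


# Three sequences of lengths 3, 2, and 4 packed together.
position_ids = [0, 1, 2, 0, 1, 0, 1, 2, 3]
cu_seqlens = cu_seqlens_from_position_ids(position_ids)
print(cu_seqlens)
```

Consecutive pairs in the result give the `[start, end)` range of each sequence, which is the information a packing-aware convolution kernel needs to avoid reading across boundaries.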

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

Possibly related PRs

Suggested reviewers

  • SalmanMohammadi
  • winglian
  • djsaunde
🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 66.67% which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (2 passed)

| Check name | Status | Explanation |
|---|---|---|
| Title check | ✅ Passed | The title accurately describes the main change: implementing fla causal-conv1d support for the qwen3-next model to enable packing with cumulative sequence lengths (cu_seqlens). |
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |


@NanoCode012 NanoCode012 marked this pull request as ready for review March 3, 2026 09:18
@NanoCode012
Collaborator Author

image

Depends on #3439

@NanoCode012 NanoCode012 added the scheduled_release label (This PR is slated for the upcoming release) Mar 3, 2026
Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (1)
examples/qwen3-next/README.md (1)

15-15: Make the causal-conv1d uninstall conditional or remove it entirely.

flash-linear-attention==0.4.1 ships with Triton conv1d implementations and does not require causal-conv1d to function. The forced pip3 uninstall -y is destructive to local environments where causal-conv1d may be needed for other packages.

Recommend updating the install step to:

pip3 install flash-linear-attention==0.4.1

If a user encounters a conflict, document causal-conv1d as an optional dependency that can be uninstalled only if explicitly needed for their setup (e.g., if using the [conv1d] extra with older compatibility requirements).

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/qwen3-next/README.md` at line 15, Replace the destructive command
"pip3 uninstall -y causal-conv1d && pip3 install flash-linear-attention==0.4.1"
with a non-destructive install (just "pip3 install
flash-linear-attention==0.4.1") and update the README to note that if users
encounter a conflict they may optionally uninstall "causal-conv1d" (or remove
the package only when explicitly required for the user's environment or the
older [conv1d] extra); do not force uninstall by default and include a brief
note about when the optional uninstall is appropriate.

ℹ️ Review info

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 18f26c1 and 4dccdfd.

📒 Files selected for processing (3)
  • examples/qwen3-next/README.md
  • examples/qwen3-next/qwen3-next-80b-a3b-qlora.yaml
  • src/axolotl/monkeypatch/models/qwen3_next/modeling.py

Comment on lines +197 to 205
# PyTorch fallback (no cu_seqlens support)
LOG.warning_once(
"FLA causal_conv1d not available. Falling back to PyTorch conv1d "
"which does not support cu_seqlens for packed sequences."
)
mixed_qkv = mixed_qkv.transpose(1, 2)
mixed_qkv = F.silu(self.conv1d(mixed_qkv)[:, :, :seq_len])
mixed_qkv = mixed_qkv.transpose(1, 2)


⚠️ Potential issue | 🟠 Major

Fail fast on packed inputs when FLA causal kernel is unavailable.

Current fallback continues execution even when cu_seqlens is present, which can silently produce incorrect results for packed sequences.

Suggested fix
             else:
-                # PyTorch fallback (no cu_seqlens support)
-                LOG.warning_once(
-                    "FLA causal_conv1d not available. Falling back to PyTorch conv1d "
-                    "which does not support cu_seqlens for packed sequences."
-                )
+                # PyTorch fallback (no cu_seqlens support)
+                if cu_seqlens is not None:
+                    raise RuntimeError(
+                        "Packed sequences require fla.modules.convolution.causal_conv1d "
+                        "(cu_seqlens support). Install flash-linear-attention or disable packing."
+                    )
+                LOG.warning_once(
+                    "FLA causal_conv1d not available. Falling back to PyTorch conv1d."
+                )
                 mixed_qkv = mixed_qkv.transpose(1, 2)
                 mixed_qkv = F.silu(self.conv1d(mixed_qkv)[:, :, :seq_len])
                 mixed_qkv = mixed_qkv.transpose(1, 2)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/axolotl/monkeypatch/models/qwen3_next/modeling.py` around lines 197 -
205, The fallback to PyTorch conv1d should fail fast when packed sequences are
present: detect the presence of cu_seqlens (the packed-sequence indicator used
by this code path) before using the PyTorch fallback in the causal_conv1d block
(the section that currently calls LOG.warning_once and then applies self.conv1d
to mixed_qkv); if cu_seqlens (or any packed-input flag passed into this
function) is set, raise an explicit error instead of continuing, otherwise keep
the existing warning and apply F.silu(self.conv1d(...)) to mixed_qkv as before.
Ensure the error references the causal_conv1d fallback and cu_seqlens so callers
know packed sequences are unsupported without the FLA kernel.
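
The fail-fast behavior described above can be sketched in isolation as follows. This is a hedged illustration rather than the merged code: `choose_conv_path` is a hypothetical helper, the returned strings are placeholders for the two branches, and the import path is the one quoted in this PR (`fla.modules.convolution.causal_conv1d`), available only when flash-linear-attention is installed.

```python
try:
    # Import path as quoted in the PR discussion; requires flash-linear-attention.
    from fla.modules.convolution import causal_conv1d as fla_causal_conv1d
except ImportError:
    fla_causal_conv1d = None


def choose_conv_path(kernel, cu_seqlens):
    """Pick the convolution path, refusing packed input without the FLA kernel.

    Hypothetical helper: returns "fla" or "torch" instead of running a kernel.
    """
    if kernel is not None:
        return "fla"
    if cu_seqlens is not None:
        # Packed sequences would silently leak across boundaries in the fallback.
        raise RuntimeError(
            "Packed sequences require fla.modules.convolution.causal_conv1d "
            "(cu_seqlens support). Install flash-linear-attention or disable packing."
        )
    return "torch"


# Unpacked input may use the PyTorch fallback; packed input must not.
print(choose_conv_path(None, None))
```

Passing the kernel in as an argument keeps the guard easy to test without depending on whether flash-linear-attention is present in the environment.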

@github-actions
Contributor

github-actions Bot commented Mar 3, 2026

📖 Documentation Preview: https://69a6b1d2defcc2e8c1aef468--resonant-treacle-0fd729.netlify.app

Deployed on Netlify from commit 025ff7b

@codecov

codecov Bot commented Mar 3, 2026

Codecov Report

❌ Patch coverage is 22.22222% with 14 lines in your changes missing coverage. Please review.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| .../axolotl/monkeypatch/models/qwen3_next/modeling.py | 22.22% | 14 Missing ⚠️ |


@winglian winglian merged commit e672d37 into main Mar 3, 2026
19 of 20 checks passed
@winglian winglian deleted the fix/qwen3-next-causal branch March 3, 2026 14:26
@winglian winglian removed the scheduled_release and ready to merge labels Mar 22, 2026