Skip to content

LoRA bridge & merge for Qwen3.5#2736

Merged
cuichenx merged 3 commits intoNVIDIA-NeMo:mainfrom
HollowMan6:qwen3.5
Mar 20, 2026
Merged

LoRA bridge & merge for Qwen3.5#2736
cuichenx merged 3 commits intoNVIDIA-NeMo:mainfrom
HollowMan6:qwen3.5

Conversation

@HollowMan6
Copy link
Contributor

@HollowMan6 HollowMan6 commented Mar 10, 2026

What does this PR do ?

Support LoRA bridge & merge for Qwen3.5

Tested with target_modules='["in_proj","out_proj","linear_proj","linear_fc1","linear_fc2","router"]'

Changelog

  • handle HF base names without .weight for LoRA suffixing
  • add GDN in-proj split logic for fused adapters
  • support packed expert LoRA stacking in streaming
  • fix confusion with mamba layers when saving checkpoints, now identify mamba using "mixer.in_proj"
image

GitHub Actions CI

See the CI sectionin the Contributing doc for how to trigger the CI. A Nvidia developer will need to approve and trigger the CI for external contributors.

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you add or update any necessary documentation?
  • Does the PR affect components that are optional to install? (Ex: Numba, Pynini, Apex etc)
    • Reviewer: Does the PR have correct import guards for all optional libraries?

If you haven't finished some of the above items you can still open "Draft" PR.

Additional Information

  • Related to # (issue)

Test results

...
│ model.visual.patch_embed.proj.weight                              │ (1152, 3, 2, 16, 16) │ bfloat16 │ cpu    │       ✅       │
│ model.visual.pos_embed.weight                                     │ (2304, 1152)         │ bfloat16 │ cpu    │       ✅       │
│ model.language_model.layers.0.mlp.gate.weight                     │ (256, 2048)          │ bfloat16 │ cpu    │       ✅       │
│ model.language_model.layers.0.mlp.shared_expert_gate.weight       │ (1, 2048)            │ bfloat16 │ cpu    │       ✅       │
│ model.language_model.layers.0.mlp.shared_expert.gate_proj.weight  │ (512, 2048)          │ bfloat16 │ cpu    │       ✅       │
│ model.language_model.layers.0.mlp.shared_expert.up_proj.weight    │ (512, 2048)          │ bfloat16 │ cpu    │       ✅       │
│ model.language_model.layers.0.mlp.shared_expert.down_proj.weight  │ (2048, 512)          │ bfloat16 │ cpu    │       ✅       │
│ model.language_model.layers.0.post_attention_layernorm.weight     │ (2048,)              │ bfloat16 │ cpu    │       ✅       │
│ model.language_model.layers.0.linear_attn.A_log                   │ (32,)                │ bfloat16 │ cpu    │       ✅       │
│ model.language_model.layers.0.linear_attn.conv1d.weight           │ (8192, 1, 4)         │ bfloat16 │ cpu    │       ✅       │
│ model.language_model.layers.0.linear_attn.dt_bias                 │ (32,)                │ bfloat16 │ cpu    │       ✅       │
│ model.language_model.layers.0.input_layernorm.weight              │ (2048,)              │ bfloat16 │ cpu    │       ✅       │
│ model.language_model.layers.0.linear_attn.in_proj_qkv.weight      │ (8192, 2048)         │ bfloat16 │ cpu    │       ✅       │
│ model.language_model.layers.0.linear_attn.in_proj_z.weight        │ (4096, 2048)         │ bfloat16 │ cpu    │       ✅       │
│ model.language_model.layers.0.linear_attn.in_proj_b.weight        │ (32, 2048)           │ bfloat16 │ cpu    │       ✅       │
│ model.language_model.layers.0.linear_attn.in_proj_a.weight        │ (32, 2048)           │ bfloat16 │ cpu    │       ✅       │
│ model.language_model.layers.0.linear_attn.norm.weight             │ (128,)               │ bfloat16 │ cpu    │       ✅       │
│ model.language_model.layers.0.linear_attn.out_proj.weight         │ (2048, 4096)         │ bfloat16 │ cpu    │       ✅       │
│ model.visual.blocks.0.mlp.linear_fc1.bias                         │ (4304,)              │ bfloat16 │ cpu    │       ✅       │
│ model.visual.blocks.0.norm2.bias                                  │ (1152,)              │ bfloat16 │ cpu    │       ✅       │
│ model.visual.blocks.0.norm2.weight                                │ (1152,)              │ bfloat16 │ cpu    │       ✅       │
│ model.visual.blocks.0.mlp.linear_fc1.weight                       │ (4304, 1152)         │ bfloat16 │ cpu    │       ✅       │
│ model.visual.blocks.0.mlp.linear_fc2.bias                         │ (1152,)              │ bfloat16 │ cpu    │       ✅       │
│ model.visual.blocks.0.mlp.linear_fc2.weight                       │ (1152, 4304)         │ bfloat16 │ cpu    │       ✅       │
│ model.visual.blocks.0.attn.proj.bias                              │ (1152,)              │ bfloat16 │ cpu    │       ✅       │
│ model.visual.blocks.0.attn.proj.weight                            │ (1152, 1152)         │ bfloat16 │ cpu    │       ✅       │
│ model.visual.blocks.0.attn.qkv.bias                               │ (3456,)              │ bfloat16 │ cpu    │       ✅       │
│ model.visual.blocks.0.norm1.bias                                  │ (1152,)              │ bfloat16 │ cpu    │       ✅       │
│ model.visual.blocks.0.norm1.weight                                │ (1152,)              │ bfloat16 │ cpu    │       ✅       │
│ model.visual.blocks.0.attn.qkv.weight                             │ (3456, 1152)         │ bfloat16 │ cpu    │       ✅       │
│ model.language_model.layers.0.mlp.experts.gate_up_proj            │ (256, 1024, 2048)    │ bfloat16 │ cpu    │       ✅       │
│ model.language_model.layers.0.mlp.experts.down_proj               │ (256, 2048, 512)     │ bfloat16 │ cpu    │       ✅       │
│ model.language_model.layers.1.mlp.gate.weight                     │ (256, 2048)          │ bfloat16 │ cpu    │       ✅       │
│ model.language_model.layers.1.mlp.shared_expert_gate.weight       │ (1, 2048)            │ bfloat16 │ cpu    │       ✅       │
│ model.language_model.layers.1.mlp.shared_expert.gate_proj.weight  │ (512, 2048)          │ bfloat16 │ cpu    │       ✅       │
│ model.language_model.layers.1.mlp.shared_expert.up_proj.weight    │ (512, 2048)          │ bfloat16 │ cpu    │       ✅       │
│ model.language_model.layers.1.mlp.shared_expert.down_proj.weight  │ (2048, 512)          │ bfloat16 │ cpu    │       ✅       │
│ model.language_model.layers.1.post_attention_layernorm.weight     │ (2048,)              │ bfloat16 │ cpu    │       ✅       │
│ model.language_model.layers.1.linear_attn.A_log                   │ (32,)                │ bfloat16 │ cpu    │       ✅       │
│ model.language_model.layers.1.linear_attn.conv1d.weight           │ (8192, 1, 4)         │ bfloat16 │ cpu    │       ✅       │
│ model.language_model.layers.1.linear_attn.dt_bias                 │ (32,)                │ bfloat16 │ cpu    │       ✅       │
│ model.language_model.layers.1.input_layernorm.weight              │ (2048,)              │ bfloat16 │ cpu    │       ✅       │
│ model.language_model.layers.1.linear_attn.in_proj_qkv.weight      │ (8192, 2048)         │ bfloat16 │ cpu    │       ✅       │
│ model.language_model.layers.1.linear_attn.in_proj_z.weight        │ (4096, 2048)         │ bfloat16 │ cpu    │       ✅       │
│ model.language_model.layers.1.linear_attn.in_proj_b.weight        │ (32, 2048)           │ bfloat16 │ cpu    │       ✅       │
│ model.language_model.layers.1.linear_attn.in_proj_a.weight        │ (32, 2048)           │ bfloat16 │ cpu    │       ✅       │
│ model.language_model.layers.1.linear_attn.norm.weight             │ (128,)               │ bfloat16 │ cpu    │       ✅       │
│ model.language_model.layers.1.linear_attn.out_proj.weight         │ (2048, 4096)         │ bfloat16 │ cpu    │       ✅       │
│ model.visual.blocks.1.mlp.linear_fc1.bias                         │ (4304,)              │ bfloat16 │ cpu    │       ✅       │
│ model.visual.blocks.1.norm2.bias                                  │ (1152,)              │ bfloat16 │ cpu    │       ✅       │
│ model.visual.blocks.1.norm2.weight                                │ (1152,)              │ bfloat16 │ cpu    │       ✅       │
│ model.visual.blocks.1.mlp.linear_fc1.weight                       │ (4304, 1152)         │ bfloat16 │ cpu    │       ✅       │
│ model.visual.blocks.1.mlp.linear_fc2.bias                         │ (1152,)              │ bfloat16 │ cpu    │       ✅       │
│ model.visual.blocks.1.mlp.linear_fc2.weight                       │ (1152, 4304)         │ bfloat16 │ cpu    │       ✅       │
│ model.visual.blocks.1.attn.proj.bias                              │ (1152,)              │ bfloat16 │ cpu    │       ✅       │
│ model.visual.blocks.1.attn.proj.weight                            │ (1152, 1152)         │ bfloat16 │ cpu    │       ✅       │
│ model.visual.blocks.1.attn.qkv.bias                               │ (3456,)              │ bfloat16 │ cpu    │       ✅       │
│ model.visual.blocks.1.norm1.bias                                  │ (1152,)              │ bfloat16 │ cpu    │       ✅       │
│ model.visual.blocks.1.norm1.weight                                │ (1152,)              │ bfloat16 │ cpu    │       ✅       │
│ model.visual.blocks.1.attn.qkv.weight                             │ (3456, 1152)         │ bfloat16 │ cpu    │       ✅       │
│ model.language_model.layers.1.mlp.experts.gate_up_proj            │ (256, 1024, 2048)    │ bfloat16 │ cpu    │       ✅       │
│ model.language_model.layers.1.mlp.experts.down_proj               │ (256, 2048, 512)     │ bfloat16 │ cpu    │       ✅       │
...

✅ Verification passed: 1026 tensors match.
💾 Saving 830 adapter tensors to adapter_weights/demo_lora.safetensors ...
✅ Done! You can now load the adapters independently of the base model.

Summary by CodeRabbit

  • New Features

    • Extended model weight conversion support for GDN-based architectures, enabling improved handling of projection weight transformations across different framework representations.
  • Tests

    • Expanded test coverage for weight conversion and merging scenarios, including new test cases for expert weight handling and framework compatibility pathways.

Copilot AI review requested due to automatic review settings March 10, 2026 22:59
@copy-pr-bot
Copy link

copy-pr-bot bot commented Mar 10, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@coderabbitai
Copy link
Contributor

coderabbitai bot commented Mar 10, 2026

📝 Walkthrough

Walkthrough

The changes extend LoRA/PEFT conversion logic to support GDN in_proj weight handling by introducing splitting and merging of fused Megatron tensors into per-base HF weight components (qkv, z, b, a) during streaming and merging operations, alongside comprehensive test coverage for these conversion paths.

Changes

Cohort / File(s) Summary
GDN in_proj Conversion Support
src/megatron/bridge/models/conversion/peft_bridge.py
Adds GDN in_proj weight splitting/merging logic to streaming and merging pathways; introduces helper functions _is_gdn_in_proj_split and _split_gdn_in_proj_linear_out_weight; extends base name resolution to tolerate missing .weight suffix; adds packed expert mode for conditional weight stacking; integrates new GDN splitting utilities from param_mapping.
Test Coverage Expansion
tests/unit_tests/models/test_model_bridge_lora.py
Adds pytest fixture for parallel state stubbing; imports GDN-specific utilities; introduces new test cases covering empty returns, grouped experts with missing expert_idx, LoRA parameter naming without weight suffix, packed expert weight streaming, and round-trip GDN in_proj splitting scenarios.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

🚥 Pre-merge checks | ✅ 2 | ❌ 2

❌ Failed checks (2 inconclusive)

Check name Status Explanation Resolution
Title check ❓ Inconclusive The title 'LoRA bridge & merge for Qwen3.5' is vague and does not clearly describe the specific technical changes made in the changeset, such as GDN in_proj handling, packed expert support, or base name resolution adjustments. Consider using a more specific title that captures the primary technical changes, such as 'Add GDN in_proj LoRA support and packed expert stacking for Qwen3.5' or similar phrasing that conveys the key improvements.
Test Results For Major Changes ❓ Inconclusive PR introduces major changes (235+ lines) to GDN handling and packed expert LoRA support with 207+ new test lines, but lacks documented test results or regression testing evidence. Provide documented test results showing new tests pass, evidence existing tests remain passing, and specific test assertions validating GDN in_proj split and packed expert stacking functionality.
✅ Passed checks (2 passed)
Check name Status Explanation
Docstring Coverage ✅ Passed No functions found in the changed files to evaluate docstring coverage. Skipping docstring coverage check.
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Post copyable unit tests in a comment
📝 Coding Plan
  • Generate coding plan for human review comments

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Tip

CodeRabbit can use Trivy to scan for security misconfigurations and secrets in Infrastructure as Code files.

Add a .trivyignore file to your project to customize which findings Trivy reports.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🧹 Nitpick comments (3)
src/megatron/bridge/models/conversion/peft_bridge.py (3)

730-731: Defensive check is appropriate but consider logging for debuggability.

The continue silently skips when expert tensors are unexpectedly empty. While this prevents crashes, it could mask configuration issues in production.

🔧 Add debug logging
                 if not per_expert_linear_in or not per_expert_linear_out:
+                    # Unexpected: packed expert mode but no expert tensors found
                     continue
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/megatron/bridge/models/conversion/peft_bridge.py` around lines 730 - 731,
The silent continue should log missing expert tensors for debuggability: inside
the loop where the check uses per_expert_linear_in and per_expert_linear_out (in
peft_bridge.py), replace the silent continue with a debug or warning log that
includes which tensor(s) are empty and contextual identifiers (e.g., layer name,
expert index, or the surrounding module/variable names available in that scope)
and then continue; ensure the log uses the existing logger (or add one) and
includes values or shapes of per_expert_linear_in and per_expert_linear_out to
aid troubleshooting.

234-240: Strengthen detection by requiring all four GDN components.

The current check verifies the required tokens exist across any names, but doesn't ensure each token appears exactly once. This could produce false positives if a single name contains multiple component substrings.

🔧 Suggested refinement
     def _is_gdn_in_proj_split(self, hf_weight_names: Iterable[str]) -> bool:
         """Check whether the provided HF names correspond to split GDN in_proj weights."""

         names = list(hf_weight_names)
+        if len(names) != 4:
+            return False
         required = {"in_proj_qkv", "in_proj_z", "in_proj_b", "in_proj_a"}
         discovered = {token for name in names for token in required if token in name}
         return discovered == required and all("linear_attn" in name for name in names)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/megatron/bridge/models/conversion/peft_bridge.py` around lines 234 - 240,
The _is_gdn_in_proj_split check should require each GDN component token to
appear exactly once and not be co-located in a single name; change the logic in
_is_gdn_in_proj_split to count occurrences of each token in required =
{"in_proj_qkv","in_proj_z","in_proj_b","in_proj_a"} across hf_weight_names and
return True only if every token’s count == 1, no single name contains more than
one required token, and all names include "linear_attn". This ensures all four
components are present exactly once and avoids false positives when multiple
tokens appear in one weight name.

754-754: Assertion could be more informative for debugging.

If this assertion fails in production, the error message won't indicate which base name or adapter caused the issue.

🔧 Enhanced assertion message
-                        assert per_base is not None, "Expected fused adapter split for expert LoRA"
+                        assert per_base is not None, (
+                            f"Expected fused adapter split for expert LoRA, "
+                            f"base_prefix={adapter_task.global_base_prefix}"
+                        )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/megatron/bridge/models/conversion/peft_bridge.py` at line 754, The
assertion "assert per_base is not None, 'Expected fused adapter split for expert
LoRA'" is not informative; update the check around per_base in peft_bridge.py
(the code that handles fused adapter split for expert LoRA) to raise a clearer
error or assertion that includes identifying context such as the base name and
adapter name (e.g., include variables like base_name, adapter_name, or the key
used to look up per_base) so the message reads something like "Expected fused
adapter split for expert LoRA: missing per_base for base '<base_name>' adapter
'<adapter_name>'"; locate the assertion near per_base usage in the function
handling fused adapter split and replace it with an assertion or ValueError that
interpolates those identifiers.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Nitpick comments:
In `@src/megatron/bridge/models/conversion/peft_bridge.py`:
- Around line 730-731: The silent continue should log missing expert tensors for
debuggability: inside the loop where the check uses per_expert_linear_in and
per_expert_linear_out (in peft_bridge.py), replace the silent continue with a
debug or warning log that includes which tensor(s) are empty and contextual
identifiers (e.g., layer name, expert index, or the surrounding module/variable
names available in that scope) and then continue; ensure the log uses the
existing logger (or add one) and includes values or shapes of
per_expert_linear_in and per_expert_linear_out to aid troubleshooting.
- Around line 234-240: The _is_gdn_in_proj_split check should require each GDN
component token to appear exactly once and not be co-located in a single name;
change the logic in _is_gdn_in_proj_split to count occurrences of each token in
required = {"in_proj_qkv","in_proj_z","in_proj_b","in_proj_a"} across
hf_weight_names and return True only if every token’s count == 1, no single name
contains more than one required token, and all names include "linear_attn". This
ensures all four components are present exactly once and avoids false positives
when multiple tokens appear in one weight name.
- Line 754: The assertion "assert per_base is not None, 'Expected fused adapter
split for expert LoRA'" is not informative; update the check around per_base in
peft_bridge.py (the code that handles fused adapter split for expert LoRA) to
raise a clearer error or assertion that includes identifying context such as the
base name and adapter name (e.g., include variables like base_name,
adapter_name, or the key used to look up per_base) so the message reads
something like "Expected fused adapter split for expert LoRA: missing per_base
for base '<base_name>' adapter '<adapter_name>'"; locate the assertion near
per_base usage in the function handling fused adapter split and replace it with
an assertion or ValueError that interpolates those identifiers.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 0fba1075-a512-42a1-9585-94465f21abe3

📥 Commits

Reviewing files that changed from the base of the PR and between de93536 and c1c3951.

📒 Files selected for processing (2)
  • src/megatron/bridge/models/conversion/peft_bridge.py
  • tests/unit_tests/models/test_model_bridge_lora.py

Copy link

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Adds Qwen3.5-specific LoRA handling to the Megatron↔HF bridge, covering HF parameter naming that omits .weight, GDN in-proj fused/split adapter behavior, and packed-expert LoRA stacking during streaming export.

Changes:

  • Allow LoRA param naming/resolution when HF base names don’t end in .weight (notably Qwen3.5 MoE expert params).
  • Add GDN in-proj fused-adapter split/merge support (split in streaming; merge path for HF export).
  • Add streaming support for “packed expert” LoRA weights (stacked along expert dim) when HF expert names don’t include experts.N.

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.

File Description
src/megatron/bridge/models/conversion/peft_bridge.py Extends adapter name resolution, adds GDN in-proj split/merge paths, and implements packed-expert LoRA streaming logic.
tests/unit_tests/models/test_model_bridge_lora.py Adds unit tests for missing .weight suffix cases, packed expert streaming, grouped-expert merge edge case, and GDN split roundtrip.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

@yaoyu-33 yaoyu-33 added t-lora model-qwen area:peft Parameter-efficient fine-tuning (LoRA, adapters) needs-review PR is ready for code review and waiting on a reviewer and removed model-qwen t-lora community-request labels Mar 11, 2026
Copy link
Contributor

@cuichenx cuichenx left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @HollowMan6 thanks for the contribution! Were you able to verify if the merged model and HF converted adapters are both correct? If so could you paste the verification results in the PR description?

From what I understand, MBridge's LoRA implementation for grouped MLP is different from HF's (we apply one adapter per grouped MLP while HF applies it per expert), so merging should work but I'm not sure about conversion to HF adapters.

@HollowMan6
Copy link
Contributor Author

HollowMan6 commented Mar 13, 2026

Hi @cuichenx ! Thank you for leaving comments. I just add more fixes 4928162 & 7d27a3e and this PR should be ready for further review. I have runned Megatron-Bridge/examples/conversion/stream_adapter_weights.py with modification to test on both LoRA and Canonical LoRA with target modules mentioned in the PR description and it looks fine, part of the result is now in the PR description.

For expert layers LoRA, current implementation is that we share an adapter per EP rank, and one of my previous PR #1817 should have already made sure that conversion to HF adapters also works.

Later we also might want to introduce unfused in_proj support for Canonical LoRA (in_proj_a, in_proj_b, in_proj_qkv, in_proj_z)

@cuichenx
Copy link
Contributor

Hi @cuichenx ! Thank you for leaving comments. I just add more fixes 4928162 & 7d27a3e and this PR should be ready for further review. I have runned Megatron-Bridge/examples/conversion/stream_adapter_weights.py with modification to test on both LoRA and Canonical LoRA with target modules mentioned in the PR description and it looks fine, part of the result is now in the PR description.

For expert layers LoRA, current implementation is that we share an adapter per EP rank, and one of my previous PR #1817 should have already made sure that conversion to HF adapters also works.

Later we also might want to introduce unfused in_proj support for Canonical LoRA (in_proj_a, in_proj_b, in_proj_qkv, in_proj_z)

Thanks, I see the results in stream_adapter_weights.py.

Were you able to verify if running HF inference with the converted HF adapter results in the expected output? My main concern is that HF has a different adapter forward pass implementation from mbridge in the fused expert layer.

@HollowMan6
Copy link
Contributor Author

Were you able to verify if running HF inference with the converted HF adapter results in the expected output? My main concern is that HF has a different adapter forward pass implementation from mbridge in the fused expert layer.

No, my current focus is on the RL LoRA side, where we use vLLM for inference, so I haven't run HF inference directly (it would also be too slow for practical use), and vLLM's Qwen3.5 LoRA support is still a work in progress:

However, vLLM has been using the fused MoE kernel for LoRA for quite some time (vllm-project/vllm@5f6cbf6). I previously ran convergence tests with a long RL run on Qwen3 30B A3B MoE, documented in #1817 (comment), and the training–inference mismatch remained at a very low level. Since this PR doesn't modify the adapter forward pass implementation, I believe that conclusion still holds and things should still be fine.

@cuichenx
Copy link
Contributor

Sounds good, I'm convinced

@cuichenx
Copy link
Contributor

/ok to test 247a6dd

@HollowMan6
Copy link
Contributor Author

Looks like there's a linting error that's unrelated to this PR (https://github.com/NVIDIA-NeMo/Megatron-Bridge/actions/runs/23172290078/job/67326715440?pr=2736) and has been fixed by f017aa8. I just updated the PR branch, would you mind retriggering the CI @cuichenx, thanks!

@cuichenx
Copy link
Contributor

/ok to test d7d7c49

@cuichenx
Copy link
Contributor

/ok to test 9ffc10a

cuichenx
cuichenx previously approved these changes Mar 19, 2026
@cuichenx cuichenx added ready-to-merge PR is approved, current, and only waiting for CI to pass before merge and removed needs-review PR is ready for code review and waiting on a reviewer labels Mar 19, 2026
- handle HF base names without .weight for LoRA suffixing
- add GDN in-proj split logic for fused adapters
- support packed expert LoRA stacking in streaming
- fix confusion with mamba layers when saving checkpoints

Signed-off-by: Hollow Man <hollowman@opensuse.org>
- required_world_size calculation logic
- logic for handling `base_layer` when not ending with `weight`
- Some linear_fc1 modules do not map to separate gate/up HF weights for Canonical LoRA

Signed-off-by: Hollow Man <hollowman@opensuse.org>
…unfused

Signed-off-by: Hollow Man <hollowman@opensuse.org>
@yaoyu-33
Copy link
Contributor

/ok to test 7c71b63

@cuichenx cuichenx merged commit e049cc0 into NVIDIA-NeMo:main Mar 20, 2026
39 of 41 checks passed
@HollowMan6 HollowMan6 deleted the qwen3.5 branch March 21, 2026 15:05
liding-nv pushed a commit that referenced this pull request Mar 22, 2026
Signed-off-by: Hollow Man <hollowman@opensuse.org>
Signed-off-by: Li Ding <liding@nvidia.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

area:peft Parameter-efficient fine-tuning (LoRA, adapters) community-request ready-to-merge PR is approved, current, and only waiting for CI to pass before merge

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants