fix: generalize LoRA layer handling for N-way fused projections#37019

Open
hallerite wants to merge 6 commits into vllm-project:main from hallerite:fix/generalize-sharded-lora-slice-a

Conversation

Contributor

@hallerite hallerite commented Mar 13, 2026

Summary

Fixes LoRA support for models with non-standard N-way fused projections (e.g., Qwen3.5 GDN layers that fuse 4 projections into a single MergedColumnParallelLinear).

Three changes in column_parallel_linear.py:

  1. MergedColumnParallelLinearWithLoRA.can_replace_layer: changed from len(packed_modules_list) == 2 to len(packed_modules_list) == len(source_layer.output_sizes) — so it correctly handles any N-way merged column parallel linear, not just 2-way.

  2. MergedColumnParallelLinearVariableSliceWithLoRA.set_lora: added greedy expansion logic for when a LoRA weight list is shorter than n_slices. This handles the case where M packed modules map to N output sizes (M < N) — each lora_b entry is greedily matched to consecutive output_sizes by dimension, then split accordingly.

  3. slice_lora_a generalization: both MergedColumnParallelLinearWithShardedLoRA (was hardcoded for 2 subloras) and MergedQKVParallelLinearWithShardedLoRA (was hardcoded for 3) now use a list comprehension over the actual lora_a inputs, supporting any N. This fixes fully_sharded_loras=True for models with >2 or >3 fused projections.

These are LoRA infrastructure fixes — no model code changes required. A related PR #36976 solves a similar problem for Qwen3.5 specifically via model-side changes; this PR instead fixes it generically at the LoRA layer so any future model with a similar fusion pattern gets the fix for free.

Related: #36372, #36478, #36976

Test plan

Tested on 2x RTX PRO 6000 Blackwell with LoRA adapters targeting GDN layers (in_proj_qkv, in_proj_z):

Qwen/Qwen3.5-9B (dense)

  • GDN LoRA, TP=1, fully_sharded_loras=False — PASS
  • GDN LoRA, TP=1, fully_sharded_loras=True — PASS
  • GDN LoRA, TP=2, fully_sharded_loras=False — PASS
  • Standard LoRA (attention/MLP only), TP=1, fully_sharded_loras=False — PASS (regression)
  • Standard LoRA (attention/MLP only), TP=1, fully_sharded_loras=True — PASS (regression)

Qwen/Qwen3.5-35B-A3B (MoE)

  • GDN LoRA, TP=2, fully_sharded_loras=False — PASS
  • GDN LoRA, TP=2, fully_sharded_loras=True — PASS

Unit tests

  • test_column_parallel_packed — all 36 parametrized cases pass (fully_shard=True/False × repeats=1,2,3 × num_loras=1,2,4)

🤖 Generated with Claude Code

hallerite and others added 4 commits March 13, 2026 21:17
Signed-off-by: hallerite <git@hallerite.com>
Signed-off-by: hallerite <git@hallerite.com>
Revert packed_modules_mapping to real HF weight names (in_proj_qkv,
in_proj_z) to fix bitsandbytes quant state stacking, and extend
MergedColumnParallelLinearVariableSliceWithLoRA to handle the mismatch
between packed module count (2) and output_sizes count (4) in GDN layers.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: hallerite <git@hallerite.com>
Signed-off-by: hallerite <git@hallerite.com>
@hallerite hallerite force-pushed the fix/generalize-sharded-lora-slice-a branch from 5ce20be to 67bf8ec on March 13, 2026 at 23:18
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request generalizes the LoRA slicing logic for sharded layers to support an arbitrary number of subloras, removing hardcoded limits. It also enhances the logic for non-sharded variable-slice layers to handle complex cases where the number of packed LoRA modules doesn't match the number of output slices, as seen in models like Qwen3.5. My review focuses on the new logic for handling these complex cases. I've identified a potential issue where a dimension mismatch could lead to a runtime error and suggested adding a validation check.

while slice_idx < self.n_slices and consumed < b_dim:
    consumed += output_sizes[slice_idx]
    slice_idx += 1
num_covered = slice_idx - start_slice

Severity: high

The greedy consumption of output_sizes to match b_dim is a good approach. However, with the loop condition consumed < b_dim, the loop can exit with consumed > b_dim if the output_sizes don't sum exactly to b_dim. This would lead to an IndexError in the subsequent slicing of b_i.

To prevent this, it's safer to assert that consumed is exactly equal to b_dim after the loop. This will ensure the dimensions are valid and provide a more informative error message if they are not.

if consumed != b_dim:
    raise ValueError(
        f"Packed LoRA B dimension {b_dim} does not match "
        f"the sum of output sizes {consumed} for LoRA {i}."
    )
num_covered = slice_idx - start_slice

hallerite and others added 2 commits March 13, 2026 23:23
Assert that consumed dimensions exactly match lora_b's shape after
greedily matching output_sizes. Prevents silent data corruption if
dimensions don't align.

Signed-off-by: hallerite <git@hallerite.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
MergedColumnParallelLinearWithShardedLoRA.slice_lora_a was hardcoded
for exactly 2 subloras, and MergedQKVParallelLinearWithShardedLoRA
for exactly 3. Both now use a list comprehension over the actual
lora_a inputs, supporting any N (needed for Qwen3.5's 4-way qkvz
fusion with fully_sharded_loras=True).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: hallerite <git@hallerite.com>
@hallerite hallerite force-pushed the fix/generalize-sharded-lora-slice-a branch from 67bf8ec to 6bbd678 on March 13, 2026 at 23:23
Contributor

@alvinttang alvinttang left a comment


Code Review: Generalize Sharded LoRA slice_lora_a for N Subloras

Good refactor that eliminates hardcoded 2-sublora and 3-sublora assumptions. The generalized list comprehensions are cleaner and will handle models like Qwen3.5 GDN that have non-standard packed module counts.

Concerns

1. can_replace_layer change in MergedColumnParallelLinearWithLoRA may break existing models

# Before:
len(packed_modules_list) == 2
# After:
len(packed_modules_list) == len(source_layer.output_sizes)

This changes the dispatch logic for which LoRA class handles a given layer. If any existing model has packed_modules_list length != len(output_sizes) (which is exactly the Qwen3.5 case you describe), this layer will now be handled by MergedQKVParallelLinearWithLoRA instead of MergedColumnParallelLinearWithLoRA. But MergedQKVParallelLinearWithLoRA.can_replace_layer also changed to accept these mismatches. Could both classes now claim to handle the same layer type? What is the priority/ordering of can_replace_layer checks?

2. The greedy matching in set_lora is fragile

while slice_idx < self.n_slices and consumed < b_dim:
    consumed += output_sizes[slice_idx]
    slice_idx += 1
if consumed != b_dim:
    raise ValueError(...)

This greedy approach assumes output_sizes are always positive and that there is a unique partition of consecutive output_sizes that sums to each b_dim. If two different partitions could produce the same sum (e.g., output_sizes [128, 128, 256] could match b_dim=256 as either [128+128] or [256]), the greedy left-to-right matching would pick the wrong one. Is this guaranteed by the model architecture? A comment explaining why the greedy approach is always correct would be valuable.
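The concern above can be made concrete with a standalone sketch of the greedy matching (a hypothetical helper distilled from the quoted loop, not the PR's actual code):

```python
def expand_lora_b(lora_b_dims, output_sizes):
    """Greedily match each packed lora_b dimension to consecutive
    output_sizes, returning how many slices each entry covers.

    Raises ValueError if no consecutive run sums exactly to b_dim,
    mirroring the assert added in the PR after the review.
    """
    slice_idx = 0
    covered = []
    for i, b_dim in enumerate(lora_b_dims):
        start = slice_idx
        consumed = 0
        while slice_idx < len(output_sizes) and consumed < b_dim:
            consumed += output_sizes[slice_idx]
            slice_idx += 1
        if consumed != b_dim:
            raise ValueError(
                f"Packed LoRA B dimension {b_dim} does not match the "
                f"sum of consecutive output sizes for entry {i}."
            )
        covered.append(slice_idx - start)
    return covered
```

On the reviewer's example, expand_lora_b([256, 256], [128, 128, 256]) returns [2, 1]: the greedy left-to-right scan assigns [128, 128] to the first entry and [256] to the second, which is only correct if the model architecture guarantees that partition is the intended one.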

3. None handling for missing subloras

if b_i is None:
    remaining = len(lora_a) - i - 1
    remaining_slices = self.n_slices - slice_idx
    count = remaining_slices - remaining

This assumes that every remaining non-None entry will cover exactly 1 slice, which may not hold if multiple entries each cover multiple slices. The count calculation remaining_slices - remaining is only correct under that assumption. If a later entry also covers multiple slices, you will over-count Nones here.

4. slice_lora_a in MergedQKVParallelLinearWithShardedLoRA

return [
    a[self.tp_rank * s.shape[2] : (self.tp_rank + 1) * s.shape[2], :]
    if a is not None
    else None
    for a, s in zip(lora_a, self.lora_a_stacked)
]

This is cleaner than the original. One edge case: if lora_a and self.lora_a_stacked have different lengths (e.g., before the expansion in set_lora runs), zip silently truncates. Consider using zip(..., strict=True) (Python 3.10+) or at least asserting equal lengths.

5. Missing test coverage
This is a non-trivial refactor of LoRA weight loading and sharding logic. There are no tests in this PR. At minimum:

  • A test for slice_lora_a with N > 3 subloras
  • A test for set_lora with the expansion path (list shorter than n_slices)
  • A test for the None sublora case during expansion

6. The trailing while len(expanded_a) < self.n_slices pad
This silently pads with None if the expansion logic doesn't produce enough slices. This could mask bugs where the greedy matching fails to consume all output_sizes. Consider raising an error instead, or at least logging a warning.

Minor

  • The can_replace_layer change removing len(packed_modules_list) == 2 hardcode is the right direction.
  • The list comprehension refactors in slice_lora_a are cleaner and more maintainable.

Overall: the generalization direction is correct, but the greedy expansion logic needs more careful analysis of edge cases and test coverage.

@hallerite hallerite marked this pull request as ready for review March 15, 2026 22:23
@hallerite hallerite requested a review from jeejeelee as a code owner March 15, 2026 22:23
@hallerite hallerite changed the title from "fix: generalize sharded LoRA slice_lora_a for N subloras" to "fix: generalize LoRA layer handling for N-way fused projections" on Mar 15, 2026
@devlup

devlup commented Mar 16, 2026

@jeejeelee and @hallerite this is good, when do you think you will be able to merge it?

@devlup

devlup commented Mar 24, 2026

@jeejeelee and @hallerite can we merge it? it's working

3 participants