fix: generalize LoRA layer handling for N-way fused projections#37019
hallerite wants to merge 6 commits into vllm-project:main
Conversation
Signed-off-by: hallerite <git@hallerite.com>
Revert packed_modules_mapping to real HF weight names (in_proj_qkv, in_proj_z) to fix bitsandbytes quant state stacking, and extend MergedColumnParallelLinearVariableSliceWithLoRA to handle the mismatch between packed module count (2) and output_sizes count (4) in GDN layers. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: hallerite <git@hallerite.com>
Signed-off-by: hallerite <git@hallerite.com>
Force-pushed 5ce20be to 67bf8ec
Code Review
This pull request generalizes the LoRA slicing logic for sharded layers to support an arbitrary number of subloras, removing hardcoded limits. It also enhances the logic for non-sharded variable-slice layers to handle complex cases where the number of packed LoRA modules doesn't match the number of output slices, as seen in models like Qwen3.5. My review focuses on the new logic for handling these complex cases. I've identified a potential issue where a dimension mismatch could lead to a runtime error and suggested adding a validation check.
```python
while slice_idx < self.n_slices and consumed < b_dim:
    consumed += output_sizes[slice_idx]
    slice_idx += 1
num_covered = slice_idx - start_slice
```
The greedy consumption of output_sizes to match b_dim is a good approach. However, the loop condition consumed < b_dim can result in consumed > b_dim if the output_sizes don't sum up exactly to b_dim. This would lead to an IndexError in the subsequent slicing of b_i.
To prevent this, it's safer to assert that consumed is exactly equal to b_dim after the loop. This will ensure the dimensions are valid and provide a more informative error message if they are not.
```python
if consumed != b_dim:
    raise ValueError(
        f"Packed LoRA B dimension {b_dim} does not match "
        f"the sum of output sizes {consumed} for LoRA {i}."
    )
num_covered = slice_idx - start_slice
```

Assert that consumed dimensions exactly match `lora_b`'s shape after greedily matching `output_sizes`. Prevents silent data corruption if dimensions don't align.
Signed-off-by: hallerite <git@hallerite.com>
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
MergedColumnParallelLinearWithShardedLoRA.slice_lora_a was hardcoded for exactly 2 subloras, and MergedQKVParallelLinearWithShardedLoRA for exactly 3. Both now use a list comprehension over the actual lora_a inputs, supporting any N (needed for Qwen3.5's 4-way qkvz fusion with fully_sharded_loras=True). Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com> Signed-off-by: hallerite <git@hallerite.com>
Force-pushed 67bf8ec to 6bbd678
alvinttang left a comment
Code Review: Generalize Sharded LoRA slice_lora_a for N Subloras
Good refactor that eliminates hardcoded 2-sublora and 3-sublora assumptions. The generalized list comprehensions are cleaner and will handle models like Qwen3.5 GDN that have non-standard packed module counts.
Concerns
1. `can_replace_layer` change in `MergedColumnParallelLinearWithLoRA` may break existing models

```python
# Before:
len(packed_modules_list) == 2
# After:
len(packed_modules_list) == len(source_layer.output_sizes)
```

This changes the dispatch logic for which LoRA class handles a given layer. If any existing model has a `packed_modules_list` length != `len(output_sizes)` (which is exactly the Qwen3.5 case you describe), this layer will now be handled by `MergedQKVParallelLinearWithLoRA` instead of `MergedColumnParallelLinearWithLoRA`. But `MergedQKVParallelLinearWithLoRA.can_replace_layer` also changed to accept these mismatches. Could both classes now claim to handle the same layer type? What is the priority/ordering of `can_replace_layer` checks?
2. The greedy matching in `set_lora` is fragile

```python
while slice_idx < self.n_slices and consumed < b_dim:
    consumed += output_sizes[slice_idx]
    slice_idx += 1
if consumed != b_dim:
    raise ValueError(...)
```

This greedy approach assumes `output_sizes` are always positive and that there is a unique partition of consecutive `output_sizes` summing to each `b_dim`. If two different partitions could produce the same sum (e.g., `output_sizes = [128, 128, 256]` could match `b_dim=256` as either `128+128` or `256`), the greedy left-to-right matching would pick the wrong one. Is this guaranteed by the model architecture? A comment explaining why the greedy approach is always correct would be valuable.
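To make the concern concrete, here is a standalone sketch (plain Python, not the vLLM implementation; all names are hypothetical) of the greedy left-to-right matching described above, including the ambiguous case:

```python
# Standalone sketch of the greedy consumption pattern from the review:
# consume consecutive output_sizes until their sum reaches b_dim.
def greedy_partition(output_sizes, b_dims):
    """Map each packed-LoRA dimension to a run of consecutive output sizes."""
    slice_idx = 0
    runs = []
    for b_dim in b_dims:
        start = slice_idx
        consumed = 0
        while slice_idx < len(output_sizes) and consumed < b_dim:
            consumed += output_sizes[slice_idx]
            slice_idx += 1
        if consumed != b_dim:  # the exact-match check the PR adds
            raise ValueError(f"b_dim {b_dim} != consumed {consumed}")
        runs.append(output_sizes[start:slice_idx])
    return runs

# Unambiguous case: the partition is uniquely determined.
print(greedy_partition([512, 512, 128, 128], [1024, 256]))
# [[512, 512], [128, 128]]

# Ambiguous case from the review: b_dim=256 could be 128+128 or 256;
# greedy always takes the leftmost run [128, 128], leaving [256] next.
print(greedy_partition([128, 128, 256], [256, 256]))
# [[128, 128], [256]]
```

This shows why the reviewer asks for an architectural guarantee: greedy matching is only safe if no prefix of the remaining `output_sizes` can sum to a `b_dim` in more than one way.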
3. `None` handling for missing subloras

```python
if b_i is None:
    remaining = len(lora_a) - i - 1
    remaining_slices = self.n_slices - slice_idx
    count = remaining_slices - remaining
```

This assumes that every remaining non-`None` entry will cover exactly one slice, which may not hold if multiple entries each cover multiple slices. The count calculation `remaining_slices - remaining` is only correct under that assumption. If a later entry also covers multiple slices, you will over-count `None`s here.
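A numeric sketch (hypothetical values, not the vLLM code path) of when the `remaining_slices - remaining` heuristic holds and when it over-counts:

```python
# Sketch of the None-count heuristic flagged above.
def none_count(n_entries, n_slices, i, slice_idx):
    """The guess for how many slices a None entry at position i covers."""
    remaining = n_entries - i - 1            # entries after the None
    remaining_slices = n_slices - slice_idx  # slices not yet assigned
    return remaining_slices - remaining

# Case the heuristic handles: entries [b0, None, b2, b3] over 4 slices,
# b0 has covered slice 0, and b2/b3 will each cover one slice.
print(none_count(n_entries=4, n_slices=4, i=1, slice_idx=1))  # 1 (correct)

# Case it over-counts: entries [b0, None, b2] over 4 slices, where b2
# actually spans two slices. The heuristic assigns the None 2 slices,
# but it should only get 1.
print(none_count(n_entries=3, n_slices=4, i=1, slice_idx=1))  # 2 (wrong)
```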
4. `slice_lora_a` in `MergedQKVParallelLinearWithShardedLoRA`

```python
return [
    a[self.tp_rank * s.shape[2] : (self.tp_rank + 1) * s.shape[2], :]
    if a is not None
    else None
    for a, s in zip(lora_a, self.lora_a_stacked)
]
```

This is cleaner than the original. One edge case: if `lora_a` and `self.lora_a_stacked` have different lengths (e.g., before the expansion in `set_lora` runs), `zip` silently truncates. Consider using `zip(..., strict=True)` (Python 3.10+) or at least asserting equal lengths.
5. Missing test coverage
This is a non-trivial refactor of LoRA weight loading and sharding logic. There are no tests in this PR. At minimum:
- A test for `slice_lora_a` with N > 3 subloras
- A test for `set_lora` with the expansion path (list shorter than `n_slices`)
- A test for the `None` sublora case during expansion
6. The trailing `while len(expanded_a) < self.n_slices` pad
This silently pads with None if the expansion logic doesn't produce enough slices. This could mask bugs where the greedy matching fails to consume all output_sizes. Consider raising an error instead, or at least logging a warning.
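A minimal sketch (hypothetical data, not the actual vLLM code path) contrasting the silent `None` pad with the fail-fast alternative suggested above:

```python
# Silent pad: hides the fact that expansion produced only 2 of 4 slices.
expanded_a = ["a0", "a1"]   # hypothetical under-produced expansion
n_slices = 4
while len(expanded_a) < n_slices:
    expanded_a.append(None)
print(expanded_a)  # ['a0', 'a1', None, None] -- mismatch masked

# Fail-fast alternative: surface the mismatch before padding happens.
def check_expansion(expanded, n_slices):
    if len(expanded) != n_slices:
        raise ValueError(
            f"expansion produced {len(expanded)} slices, expected {n_slices}"
        )
    return expanded

try:
    check_expansion(["a0", "a1"], 4)
except ValueError as exc:
    print(exc)
```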
Minor
- The `can_replace_layer` change removing the `len(packed_modules_list) == 2` hardcode is the right direction.
- The list comprehension refactors in `slice_lora_a` are cleaner and more maintainable.
Overall: the generalization direction is correct, but the greedy expansion logic needs more careful analysis of edge cases and test coverage.
@jeejeelee and @hallerite this is good, when do you think you will be able to merge it?

@jeejeelee and @hallerite can we merge it? It's working.
Summary
Fixes LoRA support for models with non-standard N-way fused projections (e.g., Qwen3.5 GDN layers that fuse 4 projections into a single `MergedColumnParallelLinear`).

Three changes in `column_parallel_linear.py`:

1. `MergedColumnParallelLinearWithLoRA.can_replace_layer`: changed from `len(packed_modules_list) == 2` to `len(packed_modules_list) == len(source_layer.output_sizes)`, so it correctly handles any N-way merged column parallel linear, not just 2-way.
2. `MergedColumnParallelLinearVariableSliceWithLoRA.set_lora`: added greedy expansion logic for when a LoRA weight list is shorter than `n_slices`. This handles the case where M packed modules map to N output sizes (M < N): each `lora_b` entry is greedily matched to consecutive `output_sizes` by dimension, then split accordingly.
3. `slice_lora_a` generalization: both `MergedColumnParallelLinearWithShardedLoRA` (was hardcoded for 2 subloras) and `MergedQKVParallelLinearWithShardedLoRA` (was hardcoded for 3) now use a list comprehension over the actual `lora_a` inputs, supporting any N. This fixes `fully_sharded_loras=True` for models with >2 or >3 fused projections.

These are LoRA infrastructure fixes; no model code changes are required. A related PR #36976 solves a similar problem for Qwen3.5 specifically via model-side changes; this PR instead fixes it generically at the LoRA layer, so any future model with a similar fusion pattern gets the fix for free.
Related: #36372, #36478, #36976
Test plan
Tested on 2x RTX PRO 6000 Blackwell with LoRA adapters targeting GDN layers (`in_proj_qkv`, `in_proj_z`):

Qwen/Qwen3.5-9B (dense)
- `fully_sharded_loras=False` — PASS
- `fully_sharded_loras=True` — PASS
- `fully_sharded_loras=False` — PASS
- `fully_sharded_loras=False` — PASS (regression)
- `fully_sharded_loras=True` — PASS (regression)

Qwen/Qwen3.5-35B-A3B (MoE)
- `fully_sharded_loras=False` — PASS
- `fully_sharded_loras=True` — PASS

Unit tests
- `test_column_parallel_packed` — all 36 parametrized cases pass (`fully_shard=True/False` × `repeats=1,2,3` × `num_loras=1,2,4`)

🤖 Generated with Claude Code