From 689d5636cdfcbc39ae791b2d816047e1cffa1a64 Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Mon, 23 Mar 2026 02:11:12 +0800 Subject: [PATCH 1/4] draft Signed-off-by: Isotr0py --- vllm/lora/layers/column_parallel_linear.py | 64 ++++++++++--- vllm/lora/model_manager.py | 37 ++++++-- vllm/model_executor/models/qwen3_5.py | 101 +++------------------ vllm/model_executor/models/qwen3_next.py | 16 ++-- 4 files changed, 101 insertions(+), 117 deletions(-) diff --git a/vllm/lora/layers/column_parallel_linear.py b/vllm/lora/layers/column_parallel_linear.py index f49a3fcbb941..80b83305912c 100644 --- a/vllm/lora/layers/column_parallel_linear.py +++ b/vllm/lora/layers/column_parallel_linear.py @@ -187,9 +187,9 @@ def __init__( # There are two LoRA layers # the output_sizes in MergedColumnParallelLinear is not sharded by tp # we need to divide it by the tp_size to get correct slices size - output_sizes = self.base_layer.output_sizes + self.output_sizes = self.base_layer.output_sizes self.output_slices = tuple( - divide(output_size, self.tp_size) for output_size in output_sizes + divide(output_size, self.tp_size) for output_size in self.output_sizes ) self.n_slices = len(self.output_slices) self.output_ids = (self.tp_rank,) * self.n_slices @@ -253,6 +253,46 @@ def slice_lora_b( ] return sliced_lora_b + def _expand_packed_lora( + self, + lora_a: list[torch.Tensor | None], + lora_b: list[torch.Tensor | None], + ) -> tuple[list[torch.Tensor | None], list[torch.Tensor | None]]: + """Expand packed adapter groups to match n_slices. + + Some adapters store weights for multiple consecutive output slices as a + single fused tensor (e.g., a single ``in_proj_qkv`` tensor covering + Q, K and V slices of a 4-slice layer). This method splits each + lora_b entry according to the layer's ``output_sizes`` and replicates + the corresponding lora_a for every slice it covers. + """ + output_sizes = self.base_layer.output_sizes + expanded_a: list[torch.Tensor | None] = [] + expanded_b: list[torch.Tensor | None] = [] + slice_idx = 0 + for a_i, b_i in zip(lora_a, lora_b): + if b_i is None: + expanded_a.append(None) + expanded_b.append(None) + slice_idx += 1 + continue + # Determine how many output slices this b_i covers. + b_rows = b_i.shape[0] + covered = 0 + cumulative = 0 + while slice_idx + covered < len(output_sizes) and cumulative < b_rows: + cumulative += output_sizes[slice_idx + covered] + covered += 1 + # Split b_i into per-slice tensors and replicate a_i for each. + start = 0 + for j in range(covered): + size = output_sizes[slice_idx + j] + expanded_b.append(b_i[start : start + size, :]) + expanded_a.append(a_i) + start += size + slice_idx += covered + return expanded_a, expanded_b + def set_lora( self, index: int, @@ -261,6 +301,12 @@ def set_lora( ): self.reset_lora(index) + # Expand packed adapter groups when they don't match n_slices. + # E.g. in_proj_qkv (covers Q+K+V) + in_proj_z as 2 groups for a + # 4-slice layer: split b_qkv by output_sizes and replicate a_qkv. + if isinstance(lora_b, list) and len(lora_b) != self.n_slices: + lora_a, lora_b = self._expand_packed_lora(lora_a, lora_b) + if self.tp_size > 1: lora_a = self.slice_lora_a(lora_a) lora_b = self.slice_lora_b(lora_b) @@ -467,18 +513,14 @@ class MergedColumnParallelLinearWithShardedLoRA(MergedColumnParallelLinearWithLo def slice_lora_a( self, lora_a: list[torch.Tensor | None] ) -> list[torch.Tensor | None]: - # NOTE: lora_a contains 2 subloras, and each sublora could be None. 
output_shard_size = self.lora_a_stacked[0].shape[2] output_start_idx = self.tp_rank * output_shard_size - lora_a = [ - lora_a[0][output_start_idx : output_start_idx + output_shard_size, :] - if lora_a[0] is not None - else None, - lora_a[1][output_start_idx : output_start_idx + output_shard_size, :] - if lora_a[1] is not None - else None, + return [ + lora_a[i][output_start_idx : output_start_idx + output_shard_size, :] + if lora_a[i] is not None + else None + for i in range(len(lora_a)) ] - return lora_a def apply(self, x: torch.Tensor, bias: torch.Tensor | None = None) -> torch.Tensor: return _mcp_apply(x, bias, self) diff --git a/vllm/lora/model_manager.py b/vllm/lora/model_manager.py index 9d3772560433..88a14b6458a0 100644 --- a/vllm/lora/model_manager.py +++ b/vllm/lora/model_manager.py @@ -554,17 +554,34 @@ def create_dummy_lora( else: parts = module_name.split(".") replacements = self.packed_modules_mapping[parts[-1]] + n_slices = getattr(module, "n_slices", len(replacements)) subloras: list[LoRALayerWeights | None] = [] - for i, r in enumerate(replacements): - lora = LoRALayerWeights.create_dummy_lora_weights( - module_name + "." + r, - module.lora_a_stacked[i].shape[-1], - module.lora_b_stacked[i].shape[-2], - rank, - module.lora_a_stacked[i].dtype, - "cpu", - ) - subloras.append(lora) + if n_slices != len(replacements): + # When a packed module has more slices than replacements + # (e.g. in_proj_qkvz has 4 slices but only 2 replacements), + # create one dummy sublora per slice so that set_lora + # receives len(lora_b) == n_slices without expansion. + for i in range(n_slices): + lora = LoRALayerWeights.create_dummy_lora_weights( + module_name + f".slice_{i}", + module.lora_a_stacked[i].shape[-1], + module.lora_b_stacked[i].shape[-2], + rank, + module.lora_a_stacked[i].dtype, + "cpu", + ) + subloras.append(lora) + else: + for i, r in enumerate(replacements): + lora = LoRALayerWeights.create_dummy_lora_weights( + module_name + "." + r, + module.lora_a_stacked[i].shape[-1], + module.lora_b_stacked[i].shape[-2], + rank, + module.lora_a_stacked[i].dtype, + "cpu", + ) + subloras.append(lora) if module.__class__.__name__ == "FusedMoEWithLoRA": # For non-gated MoE, pad subloras to 3 elements per expert # to match pack_moe expectations (w1, w2, None for w3) diff --git a/vllm/model_executor/models/qwen3_5.py b/vllm/model_executor/models/qwen3_5.py index daca52821e0f..ee96335df42c 100644 --- a/vllm/model_executor/models/qwen3_5.py +++ b/vllm/model_executor/models/qwen3_5.py @@ -41,7 +41,6 @@ GemmaRMSNorm as Qwen3_5RMSNorm, ) from vllm.model_executor.layers.linear import ( - ColumnParallelLinear, MergedColumnParallelLinear, ) from vllm.model_executor.layers.logits_processor import LogitsProcessor @@ -131,40 +130,6 @@ def fix_query_key_value_ordering( "Qwen3.5 Series dont need to fix query key value ordering" ) - def __init__( - self, - config: Qwen3_5Config, - vllm_config: VllmConfig, - prefix: str = "", - ) -> None: - create_in_proj_qkvz = vllm_config.lora_config is None - super().__init__( - config, - vllm_config=vllm_config, - prefix=prefix, - create_in_proj_qkvz=create_in_proj_qkvz, - ) - if vllm_config.lora_config is not None: - # Separate in_proj_qkv (Q,K,V) and in_proj_z for LoRA compatibility. - # Use MergedColumnParallelLinear for in_proj_qkv because GDN can have - # linear_num_key_heads != linear_num_value_heads (e.g. 16 vs 32), so - # output sizes [key_dim, key_dim, value_dim] are not representable - # with a single QKVParallelLinear (which ties K and V head counts). 
- self.in_proj_qkv = MergedColumnParallelLinear( - input_size=self.hidden_size, - output_sizes=[self.key_dim, self.key_dim, self.value_dim], - bias=False, - quant_config=vllm_config.quant_config, - prefix=f"{prefix}.in_proj_qkv", - ) - self.in_proj_z = ColumnParallelLinear( - input_size=self.hidden_size, - output_size=self.value_dim, - bias=False, - quant_config=vllm_config.quant_config, - prefix=f"{prefix}.in_proj_z", - ) - def create_qkvz_proj( self, hidden_size: int, @@ -215,21 +180,15 @@ def forward( # ============================================================ # Part 1: Input Projection # ============================================================ - if hasattr(self, "in_proj_qkv"): - # LoRA path: separate in_proj_qkv and in_proj_z - mixed_qkv, _ = self.in_proj_qkv(hidden_states) - ba, _ = self.in_proj_ba(hidden_states) - z, _ = self.in_proj_z(hidden_states) - else: - mixed_qkvz, ba = torch.ops.vllm.gdn_in_proj( - hidden_states, - sum(self.in_proj_qkvz.output_sizes) // self.tp_size, - sum(self.in_proj_ba.output_sizes) // self.tp_size, - self.prefix, - ) - qkv_size = (self.key_dim * 2 + self.value_dim) // self.tp_size - z_size = self.value_dim // self.tp_size - mixed_qkv, z = mixed_qkvz.split([qkv_size, z_size], dim=-1) + mixed_qkvz, ba = torch.ops.vllm.gdn_in_proj( + hidden_states, + sum(self.in_proj_qkvz.output_sizes) // self.tp_size, + sum(self.in_proj_ba.output_sizes) // self.tp_size, + self.prefix, + ) + qkv_size = (self.key_dim * 2 + self.value_dim) // self.tp_size + z_size = self.value_dim // self.tp_size + mixed_qkv, z = mixed_qkvz.split([qkv_size, z_size], dim=-1) z = z.reshape(z.size(0), -1, self.head_v_dim) b, a = ba.chunk(2, dim=-1) @@ -368,7 +327,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.num_redundant_experts = eplb_config.num_redundant_experts self.config = config - self.enable_lora = vllm_config.lora_config is not None self.vocab_size = config.vocab_size @@ -427,6 +385,9 @@ def load_fused_expert_weights( def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) + # GDN + ("in_proj_qkvz", "in_proj_qkv", (0, 1, 2)), + ("in_proj_qkvz", "in_proj_z", 3), # self attention ("qkv_proj", "q_proj", "q"), ("qkv_proj", "k_proj", "k"), @@ -438,21 +399,6 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: ("in_proj_ba", "in_proj_a", 1), ] - if self.enable_lora: - stacked_params_mapping.extend( - [ - ("in_proj_qkv", "in_proj_qkv", (0, 1, 2)), - ("in_proj_z", "in_proj_z", 0), - ] - ) - else: - stacked_params_mapping.extend( - [ - ("in_proj_qkvz", "in_proj_qkv", (0, 1, 2)), - ("in_proj_qkvz", "in_proj_z", 3), - ] - ) - params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() expert_params_mapping = self.get_expert_mapping() @@ -500,10 +446,7 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: continue param = params_dict[name] weight_loader = param.weight_loader - if param_name == "in_proj_z" and self.enable_lora: - weight_loader(param, loaded_weight) - else: - weight_loader(param, loaded_weight, shard_id) + weight_loader(param, loaded_weight, shard_id) break else: is_expert_weight = False @@ -633,15 +576,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") ) - # When LoRA is enabled, GDN uses separate in_proj_qkv and in_proj_z - # instead of merged in_proj_qkvz; pack mapping must match. 
- if vllm_config.lora_config: - base = getattr(Qwen3_5ForCausalLMBase, "packed_modules_mapping", {}) - self.packed_modules_mapping = {k: list(v) for k, v in base.items()} - self.packed_modules_mapping.pop("in_proj_qkvz", None) - self.packed_modules_mapping["in_proj_qkv"] = ["in_proj_qkv"] - self.packed_modules_mapping["in_proj_z"] = ["in_proj_z"] - if get_pp_group().is_last_rank: if config.tie_word_embeddings: self.lm_head = self.model.embed_tokens @@ -763,14 +697,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "model"): ) def update_packed_mapping(self, enable_lora: bool): - # When LoRA is enabled, GDN uses separate in_proj_qkv and in_proj_z - if enable_lora: - base = getattr( - Qwen3_5ForConditionalGeneration, "packed_modules_mapping", {} - ) - self.packed_modules_mapping = {k: list(v) for k, v in base.items()} - self.packed_modules_mapping.pop("in_proj_qkvz", None) - self.packed_modules_mapping["in_proj_qkv"] = ["in_proj_qkv"] + pass def embed_input_ids( self, diff --git a/vllm/model_executor/models/qwen3_next.py b/vllm/model_executor/models/qwen3_next.py index c972570532f6..e77a0ef8bc64 100644 --- a/vllm/model_executor/models/qwen3_next.py +++ b/vllm/model_executor/models/qwen3_next.py @@ -402,7 +402,6 @@ def __init__( config: Qwen3NextConfig, vllm_config: VllmConfig, prefix: str = "", - create_in_proj_qkvz: bool = True, ) -> None: super().__init__() self.tp_size = get_tensor_model_parallel_world_size() @@ -454,14 +453,13 @@ def __init__( # we need to create qkvz_proj adaptively here. # When create_in_proj_qkvz is False (e.g. LoRA enabled in Qwen3.5), # the subclass creates in_proj_qkv and in_proj_z separately. - if create_in_proj_qkvz: - self.in_proj_qkvz = self.create_qkvz_proj( - hidden_size=self.hidden_size, - key_dim=self.key_dim, - value_dim=self.value_dim, - quant_config=quant_config, - prefix=f"{prefix}.in_proj_qkvz", - ) + self.in_proj_qkvz = self.create_qkvz_proj( + hidden_size=self.hidden_size, + key_dim=self.key_dim, + value_dim=self.value_dim, + quant_config=quant_config, + prefix=f"{prefix}.in_proj_qkvz", + ) # ba_proj doesn't support blockwise fp8 quantization. # Qwen3-Next and Qwen3.5 have different in_proj_ba checkpoint # layouts, so we use a factory method to create the projection. From 26d556687d5a95baa89767a313049bbd93b33f62 Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Tue, 24 Mar 2026 01:37:06 +0800 Subject: [PATCH 2/4] clean Signed-off-by: Isotr0py --- vllm/lora/layers/column_parallel_linear.py | 63 +++++++++++----------- vllm/model_executor/models/qwen3_5.py | 5 -- 2 files changed, 31 insertions(+), 37 deletions(-) diff --git a/vllm/lora/layers/column_parallel_linear.py b/vllm/lora/layers/column_parallel_linear.py index 80b83305912c..cc1ae17cf478 100644 --- a/vllm/lora/layers/column_parallel_linear.py +++ b/vllm/lora/layers/column_parallel_linear.py @@ -253,44 +253,43 @@ def slice_lora_b( ] return sliced_lora_b - def _expand_packed_lora( + def expand_packed_lora( self, - lora_a: list[torch.Tensor | None], - lora_b: list[torch.Tensor | None], - ) -> tuple[list[torch.Tensor | None], list[torch.Tensor | None]]: - """Expand packed adapter groups to match n_slices. - - Some adapters store weights for multiple consecutive output slices as a - single fused tensor (e.g., a single ``in_proj_qkv`` tensor covering - Q, K and V slices of a 4-slice layer). This method splits each - lora_b entry according to the layer's ``output_sizes`` and replicates - the corresponding lora_a for every slice it covers. 
+ lora_a: list[torch.Tensor], + lora_b: list[torch.Tensor], + ) -> tuple[list[torch.Tensor], list[torch.Tensor]]: """ - output_sizes = self.base_layer.output_sizes - expanded_a: list[torch.Tensor | None] = [] - expanded_b: list[torch.Tensor | None] = [] - slice_idx = 0 + Expand packed adapter groups when they don't match n_slices. + E.g. in_proj_qkv (covers Q+K+V) + in_proj_z + """ + # FIXME(Isotr0py): Currently, we assume multiple slices are always + # like qkv in qkvz (start from 0). We need to think about what if + # slices don't start from 0 in the future. + expanded_a: list[torch.Tensor] = [] + expanded_b: list[torch.Tensor] = [] + start_idx = 0 for a_i, b_i in zip(lora_a, lora_b): - if b_i is None: - expanded_a.append(None) - expanded_b.append(None) - slice_idx += 1 - continue - # Determine how many output slices this b_i covers. - b_rows = b_i.shape[0] - covered = 0 - cumulative = 0 - while slice_idx + covered < len(output_sizes) and cumulative < b_rows: - cumulative += output_sizes[slice_idx + covered] - covered += 1 + # Determine which output slices this b_i covers. + b_rows, cu_rows, covered = b_i.shape[0], 0, 0 + for i in range(start_idx, self.n_slices): + cu_rows += self.output_sizes[i] + if cu_rows == b_rows: + covered = i - start_idx + 1 + break + else: + raise ValueError( + f"Cannot determine how to split lora_b with {b_rows} rows " + f"into {self.n_slices} slices with output sizes " + f"{self.output_sizes} starting from index {start_idx}." + ) # Split b_i into per-slice tensors and replicate a_i for each. start = 0 for j in range(covered): - size = output_sizes[slice_idx + j] + size = self.output_sizes[start_idx + j] expanded_b.append(b_i[start : start + size, :]) expanded_a.append(a_i) start += size - slice_idx += covered + start_idx += covered return expanded_a, expanded_b def set_lora( @@ -305,7 +304,7 @@ def set_lora( # E.g. in_proj_qkv (covers Q+K+V) + in_proj_z as 2 groups for a # 4-slice layer: split b_qkv by output_sizes and replicate a_qkv. 
if isinstance(lora_b, list) and len(lora_b) != self.n_slices: - lora_a, lora_b = self._expand_packed_lora(lora_a, lora_b) + lora_a, lora_b = self.expand_packed_lora(lora_a, lora_b) if self.tp_size > 1: lora_a = self.slice_lora_a(lora_a) @@ -516,8 +515,8 @@ def slice_lora_a( output_shard_size = self.lora_a_stacked[0].shape[2] output_start_idx = self.tp_rank * output_shard_size return [ - lora_a[i][output_start_idx : output_start_idx + output_shard_size, :] - if lora_a[i] is not None + lora_a_i[output_start_idx : output_start_idx + output_shard_size, :] + if (lora_a_i := lora_a[i]) is not None else None for i in range(len(lora_a)) ] diff --git a/vllm/model_executor/models/qwen3_5.py b/vllm/model_executor/models/qwen3_5.py index ee96335df42c..e20a19825623 100644 --- a/vllm/model_executor/models/qwen3_5.py +++ b/vllm/model_executor/models/qwen3_5.py @@ -668,7 +668,6 @@ class Qwen3_5ForConditionalGeneration(Qwen3VLForConditionalGeneration, IsHybrid) def __init__(self, *, vllm_config: VllmConfig, prefix: str = "model"): # protocols have not __init__ method, so we need to use nn.Module.__init__ nn.Module.__init__(self) - self.update_packed_mapping(enable_lora=vllm_config.lora_config is not None) config: Qwen3_5Config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config multimodal_config = vllm_config.model_config.multimodal_config @@ -696,9 +695,6 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "model"): self.language_model.make_empty_intermediate_tensors ) - def update_packed_mapping(self, enable_lora: bool): - pass - def embed_input_ids( self, input_ids: torch.Tensor, @@ -885,7 +881,6 @@ class Qwen3_5MoeForConditionalGeneration( def __init__(self, *, vllm_config: VllmConfig, prefix: str = "model"): # protocols have not __init__ method, so we need to use nn.Module.__init__ nn.Module.__init__(self) - self.update_packed_mapping(enable_lora=vllm_config.lora_config is not None) config: Qwen3_5MoeConfig = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config multimodal_config = vllm_config.model_config.multimodal_config From 9c68c6287d1625a8189195862b02ba84ff487bae Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Mon, 23 Mar 2026 17:46:38 +0000 Subject: [PATCH 3/4] clean Signed-off-by: Isotr0py --- vllm/lora/model_manager.py | 38 +++++++++++++------------------------- 1 file changed, 13 insertions(+), 25 deletions(-) diff --git a/vllm/lora/model_manager.py b/vllm/lora/model_manager.py index 88a14b6458a0..6a1e31d420a0 100644 --- a/vllm/lora/model_manager.py +++ b/vllm/lora/model_manager.py @@ -556,32 +556,20 @@ def create_dummy_lora( replacements = self.packed_modules_mapping[parts[-1]] n_slices = getattr(module, "n_slices", len(replacements)) subloras: list[LoRALayerWeights | None] = [] + # HACK: overrides replacements for qkvz = qkv + z case. + # Any better methods to handle this case? if n_slices != len(replacements): - # When a packed module has more slices than replacements - # (e.g. in_proj_qkvz has 4 slices but only 2 replacements), - # create one dummy sublora per slice so that set_lora - # receives len(lora_b) == n_slices without expansion. - for i in range(n_slices): - lora = LoRALayerWeights.create_dummy_lora_weights( - module_name + f".slice_{i}", - module.lora_a_stacked[i].shape[-1], - module.lora_b_stacked[i].shape[-2], - rank, - module.lora_a_stacked[i].dtype, - "cpu", - ) - subloras.append(lora) - else: - for i, r in enumerate(replacements): - lora = LoRALayerWeights.create_dummy_lora_weights( - module_name + "." 
+ r, - module.lora_a_stacked[i].shape[-1], - module.lora_b_stacked[i].shape[-2], - rank, - module.lora_a_stacked[i].dtype, - "cpu", - ) - subloras.append(lora) + replacements = [f"slice_{i}" for i in range(n_slices)] + for i, r in enumerate(replacements): + lora = LoRALayerWeights.create_dummy_lora_weights( + module_name + "." + r, + module.lora_a_stacked[i].shape[-1], + module.lora_b_stacked[i].shape[-2], + rank, + module.lora_a_stacked[i].dtype, + "cpu", + ) + subloras.append(lora) if module.__class__.__name__ == "FusedMoEWithLoRA": # For non-gated MoE, pad subloras to 3 elements per expert # to match pack_moe expectations (w1, w2, None for w3) From 42f1d1bef1a03c17f5fa364c40e6befc353076ca Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Tue, 24 Mar 2026 01:48:41 +0800 Subject: [PATCH 4/4] clean Signed-off-by: Isotr0py --- vllm/lora/layers/column_parallel_linear.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/vllm/lora/layers/column_parallel_linear.py b/vllm/lora/layers/column_parallel_linear.py index cc1ae17cf478..e78137ee2a5f 100644 --- a/vllm/lora/layers/column_parallel_linear.py +++ b/vllm/lora/layers/column_parallel_linear.py @@ -262,9 +262,6 @@ def expand_packed_lora( Expand packed adapter groups when they don't match n_slices. E.g. in_proj_qkv (covers Q+K+V) + in_proj_z """ - # FIXME(Isotr0py): Currently, we assume multiple slices are always - # like qkv in qkvz (start from 0). We need to think about what if - # slices don't start from 0 in the future. expanded_a: list[torch.Tensor] = [] expanded_b: list[torch.Tensor] = [] start_idx = 0
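
Note (illustration only, not part of the patches): the sketch below mirrors the splitting that expand_packed_lora performs when an adapter stores fewer weight groups than the layer has output slices, e.g. a fused in_proj_qkv tensor plus a separate in_proj_z tensor feeding a 4-slice in_proj_qkvz layer. The helper name expand_packed_lora_sketch and every dimension used here are assumptions chosen for the example; only the splitting logic follows the patched method.

# Standalone sketch of the expansion behaviour, under assumed toy shapes.
import torch


def expand_packed_lora_sketch(
    lora_a: list[torch.Tensor],
    lora_b: list[torch.Tensor],
    output_sizes: list[int],
) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
    expanded_a, expanded_b = [], []
    start_idx = 0
    for a_i, b_i in zip(lora_a, lora_b):
        # Find how many consecutive output slices this lora_b group covers.
        b_rows, cu_rows, covered = b_i.shape[0], 0, 0
        for i in range(start_idx, len(output_sizes)):
            cu_rows += output_sizes[i]
            if cu_rows == b_rows:
                covered = i - start_idx + 1
                break
        else:
            raise ValueError("lora_b rows do not align with output_sizes")
        # Split lora_b per slice and reuse the same lora_a for each slice.
        start = 0
        for j in range(covered):
            size = output_sizes[start_idx + j]
            expanded_b.append(b_i[start : start + size, :])
            expanded_a.append(a_i)
            start += size
        start_idx += covered
    return expanded_a, expanded_b


# Assumed dimensions: a 4-slice layer [Q, K, V, Z], one fused Q+K+V adapter
# group plus a separate Z group, with lora_b laid out as [out_features, rank].
rank, key_dim, value_dim, hidden = 8, 32, 64, 128
a_qkv = torch.randn(rank, hidden)
b_qkv = torch.randn(2 * key_dim + value_dim, rank)
a_z = torch.randn(rank, hidden)
b_z = torch.randn(value_dim, rank)
a4, b4 = expand_packed_lora_sketch(
    [a_qkv, a_z], [b_qkv, b_z], [key_dim, key_dim, value_dim, value_dim]
)
assert len(a4) == len(b4) == 4
assert [b.shape[0] for b in b4] == [key_dim, key_dim, value_dim, value_dim]

Replicating lora_a for every slice it covers keeps the one-(A, B)-pair-per-slice layout that MergedColumnParallelLinearWithLoRA.set_lora expects, at the cost of holding the shared A tensor once per covered slice.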