diff --git a/tests/kernels/moe/test_moe_layer.py b/tests/kernels/moe/test_moe_layer.py index 07c04a168026..14cfd00c2bdd 100644 --- a/tests/kernels/moe/test_moe_layer.py +++ b/tests/kernels/moe/test_moe_layer.py @@ -1620,18 +1620,13 @@ def _parallel_worker( else: print("F", end="") finally: - # Note: for some reason DeepEP buffers don't seem to be - # entirely reusable on B200. In order to work around this - # we clear the all2all manager's cache after each testpoint. - cap = current_platform.get_device_capability() - if ( - cap is not None - and cap.major == 10 - and ( - test_config.backend == "deepep_low_latency" - or test_config.backend == "deepep_high_throughput" - ) - ): + # DeepEP managers are not reliably reusable across many subtests in + # a single worker process. Tear them down after each DeepEP case so + # later subtests do not inherit stale communication state. + if test_config.backend in { + "deepep_low_latency", + "deepep_high_throughput", + }: torch.accelerator.synchronize() all2all_manager = get_ep_group().device_communicator.all2all_manager if all2all_manager is not None: diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py index c2b4f5515644..a0028687a32f 100644 --- a/tests/lora/test_layers.py +++ b/tests/lora/test_layers.py @@ -44,6 +44,7 @@ VocabParallelEmbedding, get_masked_input_and_mask, ) +from vllm.model_executor.models.deepseek_v2 import DeepSeekV2FusedQkvAProjLinear from vllm.platforms import current_platform from vllm.utils.torch_utils import set_random_seed @@ -1422,7 +1423,107 @@ def test_variable_slice_lora_class_selection(default_vllm_config, dist_init): f"for 2 packed modules, got {type(selected_layer_merged).__name__}" ) - # Case 5: Plain ColumnParallelLinear (not merged) - common in many models + fully_sharded_tp_lora_config = LoRAConfig( + max_loras=8, + max_lora_rank=16, + lora_dtype=torch.float16, + fully_sharded_loras=True, + ) + fully_sharded_tp_layer = MergedColumnParallelLinear( + 4096, [2048, 2048], bias=False, params_dtype=torch.float16 + ) + fully_sharded_tp_layer.tp_size = 2 + + assert not MergedColumnParallelLinearWithLoRA.can_replace_layer( + source_layer=fully_sharded_tp_layer, + lora_config=fully_sharded_tp_lora_config, + packed_modules_list=packed_modules_two, + ), "Generic merged wrapper should reject fully sharded TP layers" + + assert MergedColumnParallelLinearWithShardedLoRA.can_replace_layer( + source_layer=fully_sharded_tp_layer, + lora_config=fully_sharded_tp_lora_config, + packed_modules_list=packed_modules_two, + ), "Sharded merged wrapper should remain eligible for fully sharded TP layers" + + selected_fully_sharded_tp_layer = from_layer( + fully_sharded_tp_layer, + max_loras=8, + lora_config=fully_sharded_tp_lora_config, + packed_modules_list=packed_modules_two, + ) + assert isinstance( + selected_fully_sharded_tp_layer, + MergedColumnParallelLinearWithShardedLoRA, + ), ( + "from_layer should select MergedColumnParallelLinearWithShardedLoRA " + "for fully sharded TP merged layers, got " + f"{type(selected_fully_sharded_tp_layer).__name__}" + ) + + # Case 5: DeepSeek's fused_qkv_a_proj should reuse the generic merged + # wrapper while preserving its custom base forward path. 
+ deepseek_fused_layer = DeepSeekV2FusedQkvAProjLinear( + 4096, [2048, 2048], prefix="model.layers.0.self_attn.fused_qkv_a_proj" + ) + selected_deepseek_layer = from_layer( + deepseek_fused_layer, + max_loras=8, + lora_config=lora_config, + packed_modules_list=packed_modules_two, + ) + assert isinstance(selected_deepseek_layer, MergedColumnParallelLinearWithLoRA), ( + "from_layer should select MergedColumnParallelLinearWithLoRA " + f"for DeepSeek fused_qkv_a_proj, got {type(selected_deepseek_layer).__name__}" + ) + + fully_sharded_lora_config = LoRAConfig( + max_loras=8, + max_lora_rank=16, + lora_dtype=torch.float16, + fully_sharded_loras=True, + ) + selected_fully_sharded_deepseek_layer = from_layer( + deepseek_fused_layer, + max_loras=8, + lora_config=fully_sharded_lora_config, + packed_modules_list=packed_modules_two, + ) + assert isinstance( + selected_fully_sharded_deepseek_layer, + MergedColumnParallelLinearWithLoRA, + ), ( + "from_layer should keep using MergedColumnParallelLinearWithLoRA " + "for fused_qkv_a_proj when the base layer is effectively unsharded, got " + f"{type(selected_fully_sharded_deepseek_layer).__name__}" + ) + + # Case 6: Generic subclass of MergedColumnParallelLinear with 2 packed + # modules should still use the generic merged wrapper. + class CustomMergedColumnParallelLinear(MergedColumnParallelLinear): + pass + + custom_merged_layer = CustomMergedColumnParallelLinear( + 4096, [2048, 2048], bias=False, params_dtype=torch.float16 + ) + assert MergedColumnParallelLinearWithLoRA.can_replace_layer( + source_layer=custom_merged_layer, + lora_config=lora_config, + packed_modules_list=packed_modules_two, + ), "MergedColumnParallelLinearWithLoRA should handle subclasses" + + selected_custom_layer = from_layer( + custom_merged_layer, + max_loras=8, + lora_config=lora_config, + packed_modules_list=packed_modules_two, + ) + assert isinstance(selected_custom_layer, MergedColumnParallelLinearWithLoRA), ( + f"from_layer should select MergedColumnParallelLinearWithLoRA " + f"for subclassed merged layers, got {type(selected_custom_layer).__name__}" + ) + + # Case 7: Plain ColumnParallelLinear (not merged) - common in many models # -> ColumnParallelLinearWithLoRA should be selected plain_column_parallel = ColumnParallelLinear( 4096, 4096, bias=False, params_dtype=torch.float16 @@ -1455,7 +1556,7 @@ def test_variable_slice_lora_class_selection(default_vllm_config, dist_init): f"for plain ColumnParallelLinear, got {type(selected_plain).__name__}" ) - # Case 6: MergedColumnParallelLinear with exactly 2 output sizes + # Case 8: MergedColumnParallelLinear with exactly 2 output sizes # and empty packed_modules_list # -> ColumnParallelLinearWithLoRA should NOT match (packed_modules_list != 1) # -> MergedColumnParallelLinearVariableSliceWithLoRA should NOT match (< 3 slices) @@ -1473,3 +1574,170 @@ def test_variable_slice_lora_class_selection(default_vllm_config, dist_init): "MergedColumnParallelLinearVariableSliceWithLoRA " "should NOT handle 2 slices even with empty packed_modules_list" ) + + +@pytest.mark.parametrize( + "wrapper_cls", + [ColumnParallelLinearWithLoRA, ColumnParallelLinearWithShardedLoRA], +) +def test_get_and_maybe_dequant_weights_accepts_lora_wrappers(dist_init, wrapper_cls): + from vllm.model_executor.layers.quantization.utils.quant_utils import ( + get_and_maybe_dequant_weights, + ) + + linear = ColumnParallelLinear(4096, 4096, bias=False, params_dtype=torch.float16) + lora_linear = wrapper_cls(linear) + + # Should work with LoRA wrappers and return [out, in] 
weights. + dequant_weight = get_and_maybe_dequant_weights(lora_linear, out_dtype=torch.float16) + assert dequant_weight.shape == linear.weight.shape + + +@torch.inference_mode() +@pytest.mark.parametrize("device", DEVICES) +@pytest.mark.parametrize("stage", STAGES) +@pytest.mark.parametrize("fully_sharded", [False, True]) +def test_deepseek_fused_qkv_a_proj_lora_preserves_base_forward( + default_vllm_config, dist_init, device, stage, fully_sharded +): + if current_platform.is_cuda_alike(): + torch.accelerator.set_device_index(device) + + torch.set_default_device(device) + dtype = torch.float16 if current_platform.is_cuda_alike() else torch.float32 + max_loras = 8 + lora_config = LoRAConfig( + max_loras=max_loras, + max_lora_rank=8, + lora_dtype=dtype, + fully_sharded_loras=fully_sharded, + ) + punica_wrapper = get_punica_wrapper(8192, 256, device, lora_config=lora_config) + assert check_punica_wrapper(punica_wrapper) + + class OffsetDeepSeekFusedQkvAProjLinear(DeepSeekV2FusedQkvAProjLinear): + def forward(self, input_): + output, output_bias = super().forward(input_) + return output + 1, output_bias + + layer = OffsetDeepSeekFusedQkvAProjLinear( + 32, [16, 16], prefix="model.layers.0.self_attn.fused_qkv_a_proj" + ) + layer.weight.data = torch.rand_like(layer.weight.data, dtype=dtype) + + lora_layer = MergedColumnParallelLinearWithLoRA(layer) + lora_layer.create_lora_weights(max_loras, lora_config) + lora_layer.set_mapping(punica_wrapper) + + id_to_index = get_random_id_to_index(1, max_loras, log=False) + active_slot = next(i for i, lora_id in enumerate(id_to_index) if lora_id == 1) + lora_a = [ + torch.rand(8, 32, dtype=dtype, device=device), + torch.rand(8, 32, dtype=dtype, device=device), + ] + lora_b = [ + torch.rand(16, 8, dtype=dtype, device=device), + torch.rand(16, 8, dtype=dtype, device=device), + ] + lora_layer.set_lora(active_slot, lora_a=lora_a, lora_b=lora_b) + + inputs, index_mapping, prompt_mapping = create_random_inputs( + active_lora_ids=[1], + num_inputs=4, + input_size=(1, 32), + input_range=(0, 1), + input_type=dtype, + device=device, + ) + lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage) + punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras, 512) + + lora_result = lora_layer(torch.cat(inputs))[0] + + expected_results = [] + for input_ in inputs: + result = layer(input_)[0] + result[:, :16] += input_ @ lora_a[0].T @ lora_b[0].T + result[:, 16:] += input_ @ lora_a[1].T @ lora_b[1].T + expected_results.append(result) + + rtol, atol = TOLERANCES[lora_result.dtype] + torch.testing.assert_close( + lora_result, torch.cat(expected_results), rtol=rtol, atol=atol + ) + + merged_layer = OffsetDeepSeekFusedQkvAProjLinear( + 32, [16, 16], prefix="model.layers.0.self_attn.fused_qkv_a_proj" + ) + merged_layer.weight.data = layer.weight.data.clone() + merged_layer.weight.data[:16].add_(lora_b[0] @ lora_a[0]) + merged_layer.weight.data[16:].add_(lora_b[1] @ lora_a[1]) + merged_result = merged_layer(torch.cat(inputs))[0] + + torch.testing.assert_close(lora_result, merged_result, rtol=rtol, atol=atol) + + +@torch.inference_mode() +@pytest.mark.parametrize("device", DEVICES) +@pytest.mark.parametrize("stage", STAGES) +def test_replicated_lora_preserves_base_forward_for_subclasses( + default_vllm_config, dist_init, device, stage +): + if current_platform.is_cuda_alike(): + torch.accelerator.set_device_index(device) + + torch.set_default_device(device) + dtype = torch.float16 if current_platform.is_cuda_alike() else torch.float32 + max_loras = 8 + 
lora_config = LoRAConfig(max_loras=max_loras, max_lora_rank=8, lora_dtype=dtype) + punica_wrapper = get_punica_wrapper(8192, 256, device, lora_config=lora_config) + assert check_punica_wrapper(punica_wrapper) + + class OffsetReplicatedLinear(ReplicatedLinear): + def forward(self, input_): + output, output_bias = super().forward(input_) + return output + 1, output_bias + + layer = OffsetReplicatedLinear(32, 16, bias=False, params_dtype=dtype) + layer.weight.data = torch.rand_like(layer.weight.data, dtype=dtype) + + lora_layer = ReplicatedLinearWithLoRA(layer) + lora_layer.create_lora_weights(max_loras, lora_config) + lora_layer.set_mapping(punica_wrapper) + + id_to_index = get_random_id_to_index(1, max_loras, log=False) + active_slot = next(i for i, lora_id in enumerate(id_to_index) if lora_id == 1) + lora_a = torch.rand(8, 32, dtype=dtype, device=device) + lora_b = torch.rand(16, 8, dtype=dtype, device=device) + lora_layer.set_lora(active_slot, lora_a=lora_a, lora_b=lora_b) + + inputs, index_mapping, prompt_mapping = create_random_inputs( + active_lora_ids=[1], + num_inputs=4, + input_size=(1, 32), + input_range=(0, 1), + input_type=dtype, + device=device, + ) + lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage) + punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras, 512) + + lora_result = lora_layer(torch.cat(inputs))[0] + + expected_results = [] + for input_ in inputs: + result = layer(input_)[0] + result += input_ @ lora_a.T @ lora_b.T + expected_results.append(result) + + rtol, atol = TOLERANCES[lora_result.dtype] + torch.testing.assert_close( + lora_result, torch.cat(expected_results), rtol=rtol, atol=atol + ) + + merged_layer = OffsetReplicatedLinear(32, 16, bias=False, params_dtype=dtype) + merged_layer.weight.data = layer.weight.data.clone() + merged_layer.weight.data.add_(lora_b @ lora_a) + merged_result = merged_layer(torch.cat(inputs))[0] + + torch.testing.assert_close(lora_result, merged_result, rtol=rtol, atol=atol) diff --git a/tests/lora/test_lora_manager.py b/tests/lora/test_lora_manager.py index e80d96f00e74..1c07dc4ae677 100644 --- a/tests/lora/test_lora_manager.py +++ b/tests/lora/test_lora_manager.py @@ -13,6 +13,7 @@ from vllm.lora.layers import ( ColumnParallelLinearWithLoRA, MergedColumnParallelLinearWithLoRA, + ReplicatedLinearWithLoRA, RowParallelLinearWithLoRA, ) from vllm.lora.lora_model import LoRAModel @@ -26,6 +27,7 @@ from vllm.lora.peft_helper import PEFTHelper from vllm.lora.request import LoRARequest from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager, WorkerLoRAManager +from vllm.model_executor.layers.fused_moe import GateLinear from vllm.platforms import current_platform from .utils import create_peft_lora @@ -132,6 +134,135 @@ def test_replace_submodules(default_vllm_config, dist_init, dummy_model): assert isinstance(model.get_submodule("layer1.dense2"), RowParallelLinearWithLoRA) +def test_wrap_replicated_linear_subclasses(default_vllm_config, dist_init, dummy_model): + from vllm.model_executor.layers.linear import ReplicatedLinear + + class CustomReplicatedLinear(ReplicatedLinear): + pass + + model = dummy_model + model.add_module("custom_gate", CustomReplicatedLinear(10, 10, bias=False)) + + manager = LoRAModelManager( + model, + 1, + 1, + 1, + LoRAConfig( + max_lora_rank=8, max_cpu_loras=8, max_loras=8, lora_dtype=DEFAULT_DTYPE + ), + torch.device(DEVICES[0]), + ) + + assert isinstance( + manager.model.get_submodule("custom_gate"), ReplicatedLinearWithLoRA + ) + + +def 
test_wrap_gate_linear(default_vllm_config, dist_init, dummy_model): + model = dummy_model + model.add_module("router_gate", GateLinear(10, 4, bias=False)) + + manager = LoRAModelManager( + model, + 1, + 1, + 1, + LoRAConfig( + max_lora_rank=8, max_cpu_loras=8, max_loras=8, lora_dtype=DEFAULT_DTYPE + ), + torch.device(DEVICES[0]), + ) + + assert isinstance( + manager.model.get_submodule("router_gate"), ReplicatedLinearWithLoRA + ) + + +def test_skip_unsupported_matched_modules(default_vllm_config, dist_init, dummy_model): + class UnsupportedContainer(nn.Module): + def __init__(self): + super().__init__() + # This name matches a supported target suffix ("dense1"), + # but nn.Linear is not currently a LoRA-wrappable layer type. + self.dense1 = nn.Linear(10, 10, bias=False) + + model = dummy_model + model.add_module("unsupported", UnsupportedContainer()) + + manager = LoRAModelManager( + model, + 1, + 1, + 1, + LoRAConfig( + max_lora_rank=8, max_cpu_loras=8, max_loras=8, lora_dtype=DEFAULT_DTYPE + ), + torch.device(DEVICES[0]), + ) + + # Should not crash and should keep unsupported matched modules unchanged. + assert isinstance(manager.model.get_submodule("unsupported.dense1"), nn.Linear) + assert "unsupported.dense1" not in manager.modules + + +def test_target_modules_fail_closed_on_unsupported_matched_modules( + default_vllm_config, dist_init, dummy_model +): + class UnsupportedContainer(nn.Module): + def __init__(self): + super().__init__() + self.dense1 = nn.Linear(10, 10, bias=False) + + model = dummy_model + model.add_module("unsupported", UnsupportedContainer()) + + with pytest.raises(ValueError, match="unsupported.dense1"): + LoRAModelManager( + model, + 1, + 1, + 1, + LoRAConfig( + max_lora_rank=8, + max_cpu_loras=8, + max_loras=8, + lora_dtype=DEFAULT_DTYPE, + target_modules=["dense1"], + ), + torch.device(DEVICES[0]), + ) + + +def test_get_dummy_lora_warmup_rank_for_fully_sharded_moe(): + manager = LoRAModelManager.__new__(LoRAModelManager) + manager.lora_config = LoRAConfig( + max_lora_rank=64, + max_cpu_loras=1, + max_loras=1, + lora_dtype=DEFAULT_DTYPE, + fully_sharded_loras=True, + ) + + class DummyModule: + def __init__(self, tp_size: int, fully_sharded: bool): + self.tp_size = tp_size + self.fully_sharded = fully_sharded + + manager.modules = { + "model.layers.0.self_attn.q_proj": DummyModule( + tp_size=32, + fully_sharded=True, + ), + "model.layers.0.mlp.experts": DummyModule( + tp_size=32, + fully_sharded=True, + ), + } + + assert manager.get_dummy_lora_warmup_rank(8) == 32 + + @pytest.mark.parametrize("device", DEVICES) def test_lora_model_manager(default_vllm_config, dist_init, dummy_model, device): model = dummy_model @@ -795,6 +926,25 @@ def test_target_modules_none_uses_all( ) +@pytest.mark.parametrize("device", DEVICES) +def test_target_modules_match_packed_runtime_modules( + default_vllm_config, dist_init, dummy_model_gate_up, device +): + """Packed runtime modules should be selected by their adapter-visible names.""" + _test_target_modules( + dummy_model_gate_up, + ["gate_proj"], + device, + expected_lora=[("gate_up_proj", MergedColumnParallelLinearWithLoRA)], + expected_no_lora=[ + ("dense1", ColumnParallelLinearWithLoRA), + ("dense2", RowParallelLinearWithLoRA), + ("layer1.dense1", ColumnParallelLinearWithLoRA), + ("layer1.dense2", RowParallelLinearWithLoRA), + ], + ) + + @pytest.mark.parametrize("device", DEVICES) def test_load_adapter_warns_on_unsupported_modules( default_vllm_config, dist_init, dummy_model_gate_up, device, tmp_path diff --git 
a/tests/lora/test_lora_utils.py b/tests/lora/test_lora_utils.py index da66aa60b0d8..603ec9297491 100644 --- a/tests/lora/test_lora_utils.py +++ b/tests/lora/test_lora_utils.py @@ -58,3 +58,24 @@ def test_exact_name_match(self): def test_exact_name_no_match(self): assert not is_in_target_modules("dense3", ["dense1", "dense2"]) + + def test_packed_parent_matches_child_target_modules(self): + assert is_in_target_modules( + "model.layers.0.mlp.gate_up_proj", + ["gate_proj", "up_proj"], + {"gate_up_proj": ["gate_proj", "up_proj"]}, + ) + + def test_packed_child_matches_parent_target_modules(self): + assert is_in_target_modules( + "model.layers.0.mlp.gate_proj", + ["gate_up_proj"], + {"gate_up_proj": ["gate_proj", "up_proj"]}, + ) + + def test_fused_parent_matches_child_target_modules(self): + assert is_in_target_modules( + "model.layers.0.self_attn.fused_qkv_a_proj", + ["q_a_proj", "kv_a_proj_with_mqa"], + {"fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"]}, + ) diff --git a/vllm/lora/layers/base_linear.py b/vllm/lora/layers/base_linear.py index 4ea6b1ec8f05..68783ae50d4b 100644 --- a/vllm/lora/layers/base_linear.py +++ b/vllm/lora/layers/base_linear.py @@ -203,7 +203,16 @@ def _apply_sync( self, x: torch.Tensor, bias: torch.Tensor | None = None ) -> torch.Tensor: output = self.base_layer.quant_method.apply(self.base_layer, x, bias) + return self._apply_lora_to_output(x, output) + def _apply_base_forward(self, x: torch.Tensor) -> torch.Tensor: + base_output = self.base_layer(x) + output = base_output[0] if isinstance(base_output, tuple) else base_output + return self._apply_lora_to_output(x, output) + + def _apply_lora_to_output( + self, x: torch.Tensor, output: torch.Tensor + ) -> torch.Tensor: original_shape = output.shape if output.ndim == 3 else None # In transformers backend, x and output have extra batch dimension like diff --git a/vllm/lora/layers/column_parallel_linear.py b/vllm/lora/layers/column_parallel_linear.py index f49a3fcbb941..aed6b5ba891e 100644 --- a/vllm/lora/layers/column_parallel_linear.py +++ b/vllm/lora/layers/column_parallel_linear.py @@ -40,11 +40,19 @@ def _mcp_apply(x, bias, layer: "ColumnParallelLinearWithLoRA"): # Since communication is needed, the buffer is directly initialized as a # tensor rather than a tuple of tensor. - buffers = torch.zeros( - (layer.n_slices, x.shape[0], layer.lora_a_stacked[0].shape[2]), + local_lora_rank = layer.lora_a_stacked[0].shape[2] + buffer_shape = (layer.n_slices, x.shape[0], local_lora_rank) + # Under torch.compile, the local-rank-1 fully-sharded path can otherwise + # get lowered to a reinterpret view with a non-canonical layout. The + # Triton shrink op mutates this buffer in place and expects the standard + # contiguous [slice, token, rank] stride contract. + buffers = torch.empty_strided( + buffer_shape, + (x.shape[0] * local_lora_rank, local_lora_rank, 1), dtype=torch.float32, device=x.device, ) + buffers.zero_() shrunk_buffers: torch.Tensor | None = layer.punica_wrapper.add_shrink( buffers, x, layer.lora_a_stacked, 1.0 @@ -86,7 +94,7 @@ def __init__(self, base_layer: ColumnParallelLinear) -> None: # The base_layer type is ColumnParallelLinear or # MergedColumnParallelLinear, their weight sharding logic is # inconsistent when TP is greater than 1. 
- self.is_merged_col_linear = type(base_layer) is MergedColumnParallelLinear + self.is_merged_col_linear = isinstance(base_layer, MergedColumnParallelLinear) self.output_size = self.base_layer.output_size_per_partition # There is only one LoRA layer self.n_slices = 1 @@ -158,7 +166,7 @@ def can_replace_layer( ) -> bool: if type(source_layer) is maybe_get_oot_by_class(ColumnParallelLinear): return True - if type(source_layer) is maybe_get_oot_by_class(MergedColumnParallelLinear): + if isinstance(source_layer, maybe_get_oot_by_class(MergedColumnParallelLinear)): if len(packed_modules_list) != 1: return False # Exclude layers with 3+ output sizes - those are handled by @@ -275,19 +283,41 @@ def set_lora( index, 0, : lora_b_i.shape[0], : lora_b_i.shape[1] ].copy_(lora_b_i, non_blocking=True) + def apply(self, x: torch.Tensor, bias: torch.Tensor | None = None) -> torch.Tensor: + merged_cls = maybe_get_oot_by_class(MergedColumnParallelLinear) + # Effectively unsharded subclasses can safely reuse their custom + # forward() implementation before applying the LoRA delta. + if ( + self.tp_size == 1 + and type(self.base_layer) is not merged_cls + and type(self.base_layer).forward is not merged_cls.forward + ): + return self._apply_base_forward(x) + return _mcp_apply(x, bias, self) + @classmethod - @_not_fully_sharded_can_replace def can_replace_layer( cls, source_layer: nn.Module, lora_config: LoRAConfig, packed_modules_list: list, model_config: PretrainedConfig | None = None, + decorate: bool = True, ) -> bool: - return ( - type(source_layer) is MergedColumnParallelLinear - and len(packed_modules_list) == 2 - ) + merged_cls = maybe_get_oot_by_class(MergedColumnParallelLinear) + if not isinstance(source_layer, merged_cls) or len(packed_modules_list) != 2: + return False + + tp_size = getattr(source_layer, "tp_size", 1) + if type(source_layer) is merged_cls: + if not decorate: + return True + return not lora_config.fully_sharded_loras or tp_size == 1 + + # Only support effectively unsharded subclasses here. Sharded + # subclasses may have custom communication semantics that the generic + # merged-column LoRA path does not know how to preserve. + return tp_size == 1 class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): @@ -607,7 +637,9 @@ def can_replace_layer( ) -> bool: # Support MergedColumnParallelLinear with 3 or more slices # (2 slices are handled by MergedColumnParallelLinearWithLoRA) - if type(source_layer) is not maybe_get_oot_by_class(MergedColumnParallelLinear): + if not isinstance( + source_layer, maybe_get_oot_by_class(MergedColumnParallelLinear) + ): return False # If packed_modules_list has 3+ items, use this class diff --git a/vllm/lora/layers/replicated_linear.py b/vllm/lora/layers/replicated_linear.py index f1f499b841ba..53ae26be4c36 100644 --- a/vllm/lora/layers/replicated_linear.py +++ b/vllm/lora/layers/replicated_linear.py @@ -46,6 +46,12 @@ def forward( return output, output_bias + def apply(self, x: torch.Tensor, bias: torch.Tensor | None = None) -> torch.Tensor: + # ReplicatedLinear subclasses such as GateLinear override forward() to + # dispatch custom kernels and/or adjust the output dtype. Apply LoRA on + # top of the actual base-layer output instead of bypassing that path. + return self._apply_base_forward(x) + # ReplicatedLinear should always be replaced, regardless of the fully # sharded LoRAs setting, because it is, by definition, copied per GPU. 
@classmethod @@ -56,7 +62,7 @@ def can_replace_layer( packed_modules_list: list, model_config: PretrainedConfig | None = None, ) -> bool: - return type(source_layer) is maybe_get_oot_by_class(ReplicatedLinear) + return isinstance(source_layer, maybe_get_oot_by_class(ReplicatedLinear)) def slice_lora_a( self, lora_a: torch.Tensor | list[torch.Tensor | None] diff --git a/vllm/lora/model_manager.py b/vllm/lora/model_manager.py index 9d3772560433..3b58031dcbab 100644 --- a/vllm/lora/model_manager.py +++ b/vllm/lora/model_manager.py @@ -437,12 +437,21 @@ def _parent_module(module_name: str) -> str: ), ) - # In some models, especially multimodal ones, layers with the same - # name may have different types, such as nn.Linear and - # ReplicatedLinear. The nn.Linear layers cannot be replaced with - # LoRA layers, leading to assertion error. The following check - # aims to prevent this error - if self.supports_mm and not isinstance(new_module, BaseLayerWithLoRA): + # Some matched modules can be unsupported by LoRA wrappers + # (e.g. subclasses with specialized forward behavior). + if not isinstance(new_module, BaseLayerWithLoRA): + error_msg = ( + "LoRA target module " + f"{module_name} ({type(module).__name__}) matched the " + "deployment configuration but could not be wrapped by any " + "LoRA layer implementation." + ) + if self.lora_config.target_modules is not None: + raise ValueError( + f"{error_msg} target_modules=" + f"{sorted(self.lora_config.target_modules)}" + ) + logger.warning_once("%s It will be ignored.", error_msg) continue self.register_module(module_name, new_module) @@ -578,6 +587,38 @@ def create_dummy_lora( model.loras[module_name] = lora return model + def get_dummy_lora_warmup_rank(self, default_rank: int) -> int: + """Return a dummy LoRA rank compatible with wrapped modules. + + Dummy LoRAs keep warmup memory low by using a small rank. Fully + sharded MoE wrappers additionally require the dummy rank to be divisible + by tensor parallel size because they shard W13 along the rank axis. + """ + if not self.lora_config.fully_sharded_loras: + return default_rank + + required_multiple = 1 + for module in self.modules.values(): + if not getattr(module, "fully_sharded", False): + continue + required_multiple = math.lcm(required_multiple, module.tp_size) + + if required_multiple == 1 or default_rank % required_multiple == 0: + return default_rank + + adjusted_rank = ( + (default_rank + required_multiple - 1) // required_multiple + ) * required_multiple + if adjusted_rank > self.lora_config.max_lora_rank: + raise ValueError( + "Unable to choose a dummy LoRA warmup rank compatible with " + "fully sharded MoE modules: " + f"default_rank={default_rank}, " + f"required_multiple={required_multiple}, " + f"max_lora_rank={self.lora_config.max_lora_rank}" + ) + return adjusted_rank + def _match_target_modules(self, module_name: str) -> bool: """Check if a module should have LoRA applied. 
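For reference, a minimal standalone sketch of the rank-rounding arithmetic that the new get_dummy_lora_warmup_rank method above performs; round_warmup_rank and its inputs are hypothetical stand-ins, while the real method reads tp_size / fully_sharded from the wrapped LoRA modules and the cap from lora_config.max_lora_rank:

import math

def round_warmup_rank(default_rank: int, fully_sharded_tp_sizes: list[int], max_lora_rank: int) -> int:
    # Hypothetical helper mirroring the diff: the required multiple is the LCM of the
    # tensor-parallel sizes of all fully sharded modules (1 when there are none).
    required_multiple = 1
    for tp_size in fully_sharded_tp_sizes:
        required_multiple = math.lcm(required_multiple, tp_size)
    if required_multiple == 1 or default_rank % required_multiple == 0:
        return default_rank
    # Round up to the next multiple, as in the diff's ceil-divide.
    adjusted_rank = ((default_rank + required_multiple - 1) // required_multiple) * required_multiple
    if adjusted_rank > max_lora_rank:
        raise ValueError("no dummy warmup rank compatible with fully sharded MoE modules")
    return adjusted_rank

# Mirrors the new unit test: fully sharded modules with tp_size=32 bump the default
# warmup rank 8 up to 32 (still within max_lora_rank=64).
assert round_warmup_rank(8, [32, 32], max_lora_rank=64) == 32
assert round_warmup_rank(8, [], max_lora_rank=64) == 8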
@@ -594,7 +635,11 @@ def _match_target_modules(self, module_name: str) -> bool: """ if not is_supported_lora_module(module_name, self.supported_lora_modules): return False - return is_in_target_modules(module_name, self.lora_config.target_modules) + return is_in_target_modules( + module_name, + self.lora_config.target_modules, + self.packed_modules_mapping, + ) def _get_punica_wrapper(self, module_name: str) -> PunicaWrapperBase | None: """ diff --git a/vllm/lora/utils.py b/vllm/lora/utils.py index 2349ace70846..2991447a6ad4 100644 --- a/vllm/lora/utils.py +++ b/vllm/lora/utils.py @@ -73,7 +73,9 @@ def get_lora_id(): return _GLOBAL_LORA_ID -_all_lora_classes: set[type[BaseLayerWithLoRA]] = { +# Order matters here: more specific wrappers must be checked before generic +# merged/column-parallel wrappers in from_layer(). +_all_lora_classes: tuple[type[BaseLayerWithLoRA], ...] = ( VocabParallelEmbeddingWithLoRA, ColumnParallelLinearWithLoRA, MergedColumnParallelLinearWithLoRA, @@ -90,7 +92,7 @@ def get_lora_id(): RowParallelLinearWithShardedLoRA, FusedMoEWithLoRA, FusedMoE3DWithLoRA, -} +) def is_moe_model(model: nn.Module) -> bool: @@ -258,6 +260,7 @@ def is_supported_lora_module( def is_in_target_modules( module_name: str, target_modules: list[str] | None, + packed_modules_mapping: dict[str, list[str]] | None = None, ) -> bool: """Check if a module passes the deployment-time target_modules filter. @@ -268,14 +271,33 @@ def is_in_target_modules( module_name: Full dot-separated module name. target_modules: Optional deployment-time restriction list from LoRAConfig.target_modules. + packed_modules_mapping: Optional model-defined mapping from packed + runtime module names to their adapter-visible submodule names + (e.g. ``{"gate_up_proj": ["gate_proj", "up_proj"]}``). Returns: True if the module passes the filter, False otherwise. """ if target_modules is None: return True + target_module_set = set(target_modules) module_suffix = module_name.split(".")[-1] - return module_suffix in set(target_modules) + if module_suffix in target_module_set or module_name in target_module_set: + return True + + if not packed_modules_mapping: + return False + + # Runtime packed parent matched by deployment-time child targets. + packed_children = packed_modules_mapping.get(module_suffix) + if packed_children and any(child in target_module_set for child in packed_children): + return True + + # Adapter-visible packed child matched by deployment-time parent target. + return any( + module_suffix in children and packed_parent in target_module_set + for packed_parent, children in packed_modules_mapping.items() + ) def get_adapter_absolute_path(lora_path: str) -> str: diff --git a/vllm/lora/worker_manager.py b/vllm/lora/worker_manager.py index bea6d015e0a6..6d8ef2db51ad 100644 --- a/vllm/lora/worker_manager.py +++ b/vllm/lora/worker_manager.py @@ -160,7 +160,11 @@ def _load_adapter(self, lora_request: LoRARequest) -> LoRAModel: lora_request.lora_path, ", ".join(sorted(expected_lora_modules_lst)), ) - elif not is_in_target_modules(module_name, target_modules): + elif not is_in_target_modules( + module_name, + target_modules, + packed_modules_mapping, + ): logger.warning_once( "LoRA module '%s' in adapter '%s' is not in the " "deployment-time target_modules restriction [%s]." 
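For reference, a simplified standalone restatement of the packed-module matching semantics that the extended is_in_target_modules above provides to both the model manager and the worker manager; matches_target is a hypothetical name and the module paths below are illustrative:

def matches_target(module_name, target_modules, packed_modules_mapping=None):
    # No deployment-time restriction: everything passes.
    if target_modules is None:
        return True
    targets = set(target_modules)
    suffix = module_name.split(".")[-1]
    if suffix in targets or module_name in targets:
        return True
    if not packed_modules_mapping:
        return False
    # Runtime packed parent (e.g. gate_up_proj) matched by a deployment-time child target.
    children = packed_modules_mapping.get(suffix)
    if children and any(child in targets for child in children):
        return True
    # Adapter-visible child matched by a deployment-time packed parent target.
    return any(
        suffix in kids and parent in targets
        for parent, kids in packed_modules_mapping.items()
    )

mapping = {"gate_up_proj": ["gate_proj", "up_proj"]}
assert matches_target("model.layers.0.mlp.gate_up_proj", ["gate_proj"], mapping)
assert matches_target("model.layers.0.mlp.gate_proj", ["gate_up_proj"], mapping)
assert not matches_target("model.layers.0.mlp.down_proj", ["gate_proj"], mapping)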
@@ -197,6 +201,9 @@ def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool: self._cached_dummy_lora = dummy_lora return self._adapter_manager.add_adapter(dummy_lora) + def get_dummy_lora_warmup_rank(self, default_rank: int) -> int: + return self._adapter_manager.get_dummy_lora_warmup_rank(default_rank) + def pin_adapter(self, adapter_id: int) -> bool: return self._adapter_manager.pin_adapter(adapter_id) diff --git a/vllm/model_executor/layers/fused_moe/oracle/unquantized.py b/vllm/model_executor/layers/fused_moe/oracle/unquantized.py index 8fcb8fa1da14..af7cb7baf963 100644 --- a/vllm/model_executor/layers/fused_moe/oracle/unquantized.py +++ b/vllm/model_executor/layers/fused_moe/oracle/unquantized.py @@ -214,6 +214,19 @@ def _return_or_raise( return backend, k_cls raise ValueError(_make_log_unsupported(backend, reason)) + # LoRA needs Triton's unfused activation/reduction hooks. Selecting the + # backend here ensures weights stay in a LoRA-compatible layout instead of + # being permuted for a backend like FlashInfer or AITER during load. + if moe_config.is_lora_enabled: + backend = UnquantizedMoeBackend.TRITON + if activation_format == mk.FusedMoEActivationFormat.BatchedExperts: + backend = UnquantizedMoeBackend.BATCHED_TRITON + return _return_or_raise( + backend, + moe_config, + activation_format, + ) + runner_backend = moe_config.moe_backend if runner_backend != "auto": requested_backend = map_unquantized_backend(runner_backend) diff --git a/vllm/model_executor/layers/quantization/utils/quant_utils.py b/vllm/model_executor/layers/quantization/utils/quant_utils.py index de76deb191db..f57eb39f42b4 100644 --- a/vllm/model_executor/layers/quantization/utils/quant_utils.py +++ b/vllm/model_executor/layers/quantization/utils/quant_utils.py @@ -356,6 +356,12 @@ def get_and_maybe_dequant_weights( from vllm.model_executor.layers.linear import UnquantizedLinearMethod from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod + # LoRA linear wrappers store quantization metadata on `base_layer`. + # Unwrap here so callers can pass either a raw linear layer or its LoRA + # wrapper without special-casing. + while hasattr(layer, "base_layer") and hasattr(layer.base_layer, "quant_method"): + layer = layer.base_layer + weight = get_attribute_fallback(layer, ["weight", "qweight", "weight_packed"]) # Unquantized layer: just return base weights diff --git a/vllm/v1/worker/lora_model_runner_mixin.py b/vllm/v1/worker/lora_model_runner_mixin.py index 53873d156f88..3a14abfc3589 100644 --- a/vllm/v1/worker/lora_model_runner_mixin.py +++ b/vllm/v1/worker/lora_model_runner_mixin.py @@ -101,9 +101,12 @@ def maybe_setup_dummy_loras( assert self.lora_manager is not None, "LoRA is not enabled" num_loras = lora_config.max_loras - lora_warmup_rank = ( + lora_warmup_rank: int = ( lora_config.max_lora_rank if lora_config.max_lora_rank < 8 else 8 ) + lora_warmup_rank = self.lora_manager.get_dummy_lora_warmup_rank( + lora_warmup_rank + ) # Make dummy lora requests lora_requests: set[LoRARequest] = { LoRARequest(