Merged
19 changes: 7 additions & 12 deletions tests/kernels/moe/test_moe_layer.py
@@ -1620,18 +1620,13 @@ def _parallel_worker(
     else:
         print("F", end="")
     finally:
-        # Note: for some reason DeepEP buffers don't seem to be
-        # entirely reusable on B200. In order to work around this
-        # we clear the all2all manager's cache after each testpoint.
-        cap = current_platform.get_device_capability()
-        if (
-            cap is not None
-            and cap.major == 10
-            and (
-                test_config.backend == "deepep_low_latency"
-                or test_config.backend == "deepep_high_throughput"
-            )
-        ):
+        # DeepEP managers are not reliably reusable across many subtests in
+        # a single worker process. Tear them down after each DeepEP case so
+        # later subtests do not inherit stale communication state.
+        if test_config.backend in {
+            "deepep_low_latency",
+            "deepep_high_throughput",
+        }:
Comment on lines 1622 to +1629
Member:
I think this breaks the CI #40637

Contributor Author:
Hmm, I don't think so, as all the CI runs passed before merge, including the specific one you mentioned: https://buildkite.com/vllm/ci/builds/62466/steps/canvas?sid=019db40c-c48c-4238-b1ca-827533eb7d09&tab=output

Also, d22887b was introduced specifically to fix that CI.

Member:
Ohh, makes sense, the nightly build happens at 2am.

             torch.accelerator.synchronize()
             all2all_manager = get_ep_group().device_communicator.all2all_manager
             if all2all_manager is not None:
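For readers skimming the hunk above: the fix replaces a device-capability gate (teardown only on B200, i.e. `cap.major == 10`) with an unconditional per-subtest teardown for DeepEP backends. A minimal runnable sketch of that control flow, where `TestConfig`, `FakeAll2AllManager`, and `cleanup()` are illustrative stand-ins rather than the real vLLM objects:

```python
# Sketch of the per-subtest teardown pattern from the diff above. Only the
# control flow mirrors the change; the names here are not vLLM APIs.
from dataclasses import dataclass

DEEPEP_BACKENDS = {"deepep_low_latency", "deepep_high_throughput"}


@dataclass
class TestConfig:
    backend: str


class FakeAll2AllManager:
    def cleanup(self) -> None:
        print("tore down DeepEP state")


def run_subtest(config: TestConfig, manager: FakeAll2AllManager | None) -> None:
    try:
        print("P", end="")  # the real worker prints P/F per subtest
    finally:
        # New behavior: tear down after every DeepEP case, on any device,
        # so later subtests do not inherit stale communication state.
        if config.backend in DEEPEP_BACKENDS and manager is not None:
            manager.cleanup()


run_subtest(TestConfig("deepep_low_latency"), FakeAll2AllManager())
```

The trade-off is a little extra teardown work on devices that never exhibited the buffer-reuse problem, in exchange for every subtest starting from a clean communication state.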
272 changes: 270 additions & 2 deletions tests/lora/test_layers.py
@@ -44,6 +44,7 @@
     VocabParallelEmbedding,
     get_masked_input_and_mask,
 )
+from vllm.model_executor.models.deepseek_v2 import DeepSeekV2FusedQkvAProjLinear
 from vllm.platforms import current_platform
 from vllm.utils.torch_utils import set_random_seed

@@ -1422,7 +1423,107 @@ def test_variable_slice_lora_class_selection(default_vllm_config, dist_init):
         f"for 2 packed modules, got {type(selected_layer_merged).__name__}"
     )
 
-    # Case 5: Plain ColumnParallelLinear (not merged) - common in many models
+    fully_sharded_tp_lora_config = LoRAConfig(
+        max_loras=8,
+        max_lora_rank=16,
+        lora_dtype=torch.float16,
+        fully_sharded_loras=True,
+    )
+    fully_sharded_tp_layer = MergedColumnParallelLinear(
+        4096, [2048, 2048], bias=False, params_dtype=torch.float16
+    )
+    fully_sharded_tp_layer.tp_size = 2
+
+    assert not MergedColumnParallelLinearWithLoRA.can_replace_layer(
+        source_layer=fully_sharded_tp_layer,
+        lora_config=fully_sharded_tp_lora_config,
+        packed_modules_list=packed_modules_two,
+    ), "Generic merged wrapper should reject fully sharded TP layers"
+
+    assert MergedColumnParallelLinearWithShardedLoRA.can_replace_layer(
+        source_layer=fully_sharded_tp_layer,
+        lora_config=fully_sharded_tp_lora_config,
+        packed_modules_list=packed_modules_two,
+    ), "Sharded merged wrapper should remain eligible for fully sharded TP layers"
+
+    selected_fully_sharded_tp_layer = from_layer(
+        fully_sharded_tp_layer,
+        max_loras=8,
+        lora_config=fully_sharded_tp_lora_config,
+        packed_modules_list=packed_modules_two,
+    )
+    assert isinstance(
+        selected_fully_sharded_tp_layer,
+        MergedColumnParallelLinearWithShardedLoRA,
+    ), (
+        "from_layer should select MergedColumnParallelLinearWithShardedLoRA "
+        "for fully sharded TP merged layers, got "
+        f"{type(selected_fully_sharded_tp_layer).__name__}"
+    )
+
+    # Case 5: DeepSeek's fused_qkv_a_proj should reuse the generic merged
+    # wrapper while preserving its custom base forward path.
+    deepseek_fused_layer = DeepSeekV2FusedQkvAProjLinear(
+        4096, [2048, 2048], prefix="model.layers.0.self_attn.fused_qkv_a_proj"
+    )
+    selected_deepseek_layer = from_layer(
+        deepseek_fused_layer,
+        max_loras=8,
+        lora_config=lora_config,
+        packed_modules_list=packed_modules_two,
+    )
+    assert isinstance(selected_deepseek_layer, MergedColumnParallelLinearWithLoRA), (
+        "from_layer should select MergedColumnParallelLinearWithLoRA "
+        f"for DeepSeek fused_qkv_a_proj, got {type(selected_deepseek_layer).__name__}"
+    )
+
+    fully_sharded_lora_config = LoRAConfig(
+        max_loras=8,
+        max_lora_rank=16,
+        lora_dtype=torch.float16,
+        fully_sharded_loras=True,
+    )
+    selected_fully_sharded_deepseek_layer = from_layer(
+        deepseek_fused_layer,
+        max_loras=8,
+        lora_config=fully_sharded_lora_config,
+        packed_modules_list=packed_modules_two,
+    )
+    assert isinstance(
+        selected_fully_sharded_deepseek_layer,
+        MergedColumnParallelLinearWithLoRA,
+    ), (
+        "from_layer should keep using MergedColumnParallelLinearWithLoRA "
+        "for fused_qkv_a_proj when the base layer is effectively unsharded, got "
+        f"{type(selected_fully_sharded_deepseek_layer).__name__}"
+    )
+
+    # Case 6: Generic subclass of MergedColumnParallelLinear with 2 packed
+    # modules should still use the generic merged wrapper.
+    class CustomMergedColumnParallelLinear(MergedColumnParallelLinear):
+        pass
+
+    custom_merged_layer = CustomMergedColumnParallelLinear(
+        4096, [2048, 2048], bias=False, params_dtype=torch.float16
+    )
+    assert MergedColumnParallelLinearWithLoRA.can_replace_layer(
+        source_layer=custom_merged_layer,
+        lora_config=lora_config,
+        packed_modules_list=packed_modules_two,
+    ), "MergedColumnParallelLinearWithLoRA should handle subclasses"
+
+    selected_custom_layer = from_layer(
+        custom_merged_layer,
+        max_loras=8,
+        lora_config=lora_config,
+        packed_modules_list=packed_modules_two,
+    )
+    assert isinstance(selected_custom_layer, MergedColumnParallelLinearWithLoRA), (
+        f"from_layer should select MergedColumnParallelLinearWithLoRA "
+        f"for subclassed merged layers, got {type(selected_custom_layer).__name__}"
+    )
+
+    # Case 7: Plain ColumnParallelLinear (not merged) - common in many models
     # -> ColumnParallelLinearWithLoRA should be selected
     plain_column_parallel = ColumnParallelLinear(
         4096, 4096, bias=False, params_dtype=torch.float16
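Before the next hunk, it may help to restate what these cases exercise: `from_layer` picks the first wrapper class whose `can_replace_layer` predicate accepts the base layer, and those predicates use `isinstance` rather than exact type checks, which is why a subclass like `CustomMergedColumnParallelLinear` still gets the generic merged wrapper. A simplified, self-contained sketch of that dispatch (a stand-in, not vLLM's actual selection code):

```python
# Toy version of the wrapper-selection logic the test exercises. The layer
# classes, predicates, and two-entry registry are illustrative stand-ins.
class BaseLayer:
    pass


class MergedLayer(BaseLayer):
    pass


class PlainWrapper:
    @classmethod
    def can_replace_layer(cls, source_layer, packed_modules_list):
        # Plain (unmerged) layers with a single packed module.
        return not isinstance(source_layer, MergedLayer) and len(packed_modules_list) == 1


class MergedWrapper:
    @classmethod
    def can_replace_layer(cls, source_layer, packed_modules_list):
        # isinstance, not an exact type check, is what lets subclasses such
        # as CustomMergedColumnParallelLinear reuse the generic wrapper.
        return isinstance(source_layer, MergedLayer) and len(packed_modules_list) == 2


def from_layer(layer, packed_modules_list):
    # First wrapper whose predicate accepts the layer wins.
    for wrapper in (MergedWrapper, PlainWrapper):
        if wrapper.can_replace_layer(layer, packed_modules_list):
            return wrapper
    return None


class CustomMergedLayer(MergedLayer):  # stands in for the test's subclass
    pass


assert from_layer(MergedLayer(), ["up", "gate"]) is MergedWrapper
assert from_layer(CustomMergedLayer(), ["up", "gate"]) is MergedWrapper
assert from_layer(BaseLayer(), ["proj"]) is PlainWrapper
```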
@@ -1455,7 +1556,7 @@ def test_variable_slice_lora_class_selection(default_vllm_config, dist_init):
         f"for plain ColumnParallelLinear, got {type(selected_plain).__name__}"
     )
 
-    # Case 6: MergedColumnParallelLinear with exactly 2 output sizes
+    # Case 8: MergedColumnParallelLinear with exactly 2 output sizes
     # and empty packed_modules_list
     # -> ColumnParallelLinearWithLoRA should NOT match (packed_modules_list != 1)
     # -> MergedColumnParallelLinearVariableSliceWithLoRA should NOT match (< 3 slices)
@@ -1473,3 +1574,170 @@ def test_variable_slice_lora_class_selection(default_vllm_config, dist_init):
         "MergedColumnParallelLinearVariableSliceWithLoRA "
         "should NOT handle 2 slices even with empty packed_modules_list"
     )
+
+
+@pytest.mark.parametrize(
+    "wrapper_cls",
+    [ColumnParallelLinearWithLoRA, ColumnParallelLinearWithShardedLoRA],
+)
+def test_get_and_maybe_dequant_weights_accepts_lora_wrappers(dist_init, wrapper_cls):
+    from vllm.model_executor.layers.quantization.utils.quant_utils import (
+        get_and_maybe_dequant_weights,
+    )
+
+    linear = ColumnParallelLinear(4096, 4096, bias=False, params_dtype=torch.float16)
+    lora_linear = wrapper_cls(linear)
+
+    # Should work with LoRA wrappers and return [out, in] weights.
+    dequant_weight = get_and_maybe_dequant_weights(lora_linear, out_dtype=torch.float16)
+    assert dequant_weight.shape == linear.weight.shape
+
+
+@torch.inference_mode()
+@pytest.mark.parametrize("device", DEVICES)
+@pytest.mark.parametrize("stage", STAGES)
+@pytest.mark.parametrize("fully_sharded", [False, True])
+def test_deepseek_fused_qkv_a_proj_lora_preserves_base_forward(
+    default_vllm_config, dist_init, device, stage, fully_sharded
+):
+    if current_platform.is_cuda_alike():
+        torch.accelerator.set_device_index(device)
+
+    torch.set_default_device(device)
+    dtype = torch.float16 if current_platform.is_cuda_alike() else torch.float32
+    max_loras = 8
+    lora_config = LoRAConfig(
+        max_loras=max_loras,
+        max_lora_rank=8,
+        lora_dtype=dtype,
+        fully_sharded_loras=fully_sharded,
+    )
+    punica_wrapper = get_punica_wrapper(8192, 256, device, lora_config=lora_config)
+    assert check_punica_wrapper(punica_wrapper)
+
+    class OffsetDeepSeekFusedQkvAProjLinear(DeepSeekV2FusedQkvAProjLinear):
+        def forward(self, input_):
+            output, output_bias = super().forward(input_)
+            return output + 1, output_bias
+
+    layer = OffsetDeepSeekFusedQkvAProjLinear(
+        32, [16, 16], prefix="model.layers.0.self_attn.fused_qkv_a_proj"
+    )
+    layer.weight.data = torch.rand_like(layer.weight.data, dtype=dtype)
+
+    lora_layer = MergedColumnParallelLinearWithLoRA(layer)
+    lora_layer.create_lora_weights(max_loras, lora_config)
+    lora_layer.set_mapping(punica_wrapper)
+
+    id_to_index = get_random_id_to_index(1, max_loras, log=False)
+    active_slot = next(i for i, lora_id in enumerate(id_to_index) if lora_id == 1)
+    lora_a = [
+        torch.rand(8, 32, dtype=dtype, device=device),
+        torch.rand(8, 32, dtype=dtype, device=device),
+    ]
+    lora_b = [
+        torch.rand(16, 8, dtype=dtype, device=device),
+        torch.rand(16, 8, dtype=dtype, device=device),
+    ]
+    lora_layer.set_lora(active_slot, lora_a=lora_a, lora_b=lora_b)
+
+    inputs, index_mapping, prompt_mapping = create_random_inputs(
+        active_lora_ids=[1],
+        num_inputs=4,
+        input_size=(1, 32),
+        input_range=(0, 1),
+        input_type=dtype,
+        device=device,
+    )
+    lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage)
+    punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras, 512)
+
+    lora_result = lora_layer(torch.cat(inputs))[0]
+
+    expected_results = []
+    for input_ in inputs:
+        result = layer(input_)[0]
+        result[:, :16] += input_ @ lora_a[0].T @ lora_b[0].T
+        result[:, 16:] += input_ @ lora_a[1].T @ lora_b[1].T
+        expected_results.append(result)
+
+    rtol, atol = TOLERANCES[lora_result.dtype]
+    torch.testing.assert_close(
+        lora_result, torch.cat(expected_results), rtol=rtol, atol=atol
+    )
+
+    merged_layer = OffsetDeepSeekFusedQkvAProjLinear(
+        32, [16, 16], prefix="model.layers.0.self_attn.fused_qkv_a_proj"
+    )
+    merged_layer.weight.data = layer.weight.data.clone()
+    merged_layer.weight.data[:16].add_(lora_b[0] @ lora_a[0])
+    merged_layer.weight.data[16:].add_(lora_b[1] @ lora_a[1])
+    merged_result = merged_layer(torch.cat(inputs))[0]
+
+    torch.testing.assert_close(lora_result, merged_result, rtol=rtol, atol=atol)
+
+
+@torch.inference_mode()
+@pytest.mark.parametrize("device", DEVICES)
+@pytest.mark.parametrize("stage", STAGES)
+def test_replicated_lora_preserves_base_forward_for_subclasses(
+    default_vllm_config, dist_init, device, stage
+):
+    if current_platform.is_cuda_alike():
+        torch.accelerator.set_device_index(device)
+
+    torch.set_default_device(device)
+    dtype = torch.float16 if current_platform.is_cuda_alike() else torch.float32
+    max_loras = 8
+    lora_config = LoRAConfig(max_loras=max_loras, max_lora_rank=8, lora_dtype=dtype)
+    punica_wrapper = get_punica_wrapper(8192, 256, device, lora_config=lora_config)
+    assert check_punica_wrapper(punica_wrapper)
+
+    class OffsetReplicatedLinear(ReplicatedLinear):
+        def forward(self, input_):
+            output, output_bias = super().forward(input_)
+            return output + 1, output_bias
+
+    layer = OffsetReplicatedLinear(32, 16, bias=False, params_dtype=dtype)
+    layer.weight.data = torch.rand_like(layer.weight.data, dtype=dtype)
+
+    lora_layer = ReplicatedLinearWithLoRA(layer)
+    lora_layer.create_lora_weights(max_loras, lora_config)
+    lora_layer.set_mapping(punica_wrapper)
+
+    id_to_index = get_random_id_to_index(1, max_loras, log=False)
+    active_slot = next(i for i, lora_id in enumerate(id_to_index) if lora_id == 1)
+    lora_a = torch.rand(8, 32, dtype=dtype, device=device)
+    lora_b = torch.rand(16, 8, dtype=dtype, device=device)
+    lora_layer.set_lora(active_slot, lora_a=lora_a, lora_b=lora_b)
+
+    inputs, index_mapping, prompt_mapping = create_random_inputs(
+        active_lora_ids=[1],
+        num_inputs=4,
+        input_size=(1, 32),
+        input_range=(0, 1),
+        input_type=dtype,
+        device=device,
+    )
+    lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=stage)
+    punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras, 512)
+
+    lora_result = lora_layer(torch.cat(inputs))[0]
+
+    expected_results = []
+    for input_ in inputs:
+        result = layer(input_)[0]
+        result += input_ @ lora_a.T @ lora_b.T
+        expected_results.append(result)
+
+    rtol, atol = TOLERANCES[lora_result.dtype]
+    torch.testing.assert_close(
+        lora_result, torch.cat(expected_results), rtol=rtol, atol=atol
+    )
+
+    merged_layer = OffsetReplicatedLinear(32, 16, bias=False, params_dtype=dtype)
+    merged_layer.weight.data = layer.weight.data.clone()
+    merged_layer.weight.data.add_(lora_b @ lora_a)
+    merged_result = merged_layer(torch.cat(inputs))[0]
+
+    torch.testing.assert_close(lora_result, merged_result, rtol=rtol, atol=atol)
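Both new tests close by checking the standard LoRA merge identity: applying the low-rank update at runtime, x W^T + x A^T B^T, must equal folding it into the weight as W + BA (with any LoRA scaling taken as 1, matching the raw products above). A standalone PyTorch sketch of that equivalence, with illustrative shapes and explicit tolerances:

```python
# Standalone check of the LoRA identity the two tests above rely on:
# x @ (W + B @ A).T == x @ W.T + x @ A.T @ B.T. Shapes are illustrative.
import torch

torch.manual_seed(0)

x = torch.rand(4, 32)       # batch of inputs, [n, in]
w = torch.rand(16, 32)      # base weight, [out, in]
lora_a = torch.rand(8, 32)  # rank-8 down-projection, [r, in]
lora_b = torch.rand(16, 8)  # rank-8 up-projection, [out, r]

# LoRA applied at runtime, as the unmerged wrapper computes it.
runtime = x @ w.T + x @ lora_a.T @ lora_b.T

# LoRA folded into the weight ahead of time, as the merged layers do.
merged = x @ (w + lora_b @ lora_a).T

torch.testing.assert_close(runtime, merged, rtol=1e-4, atol=1e-4)
print("runtime and merged LoRA paths agree")
```

The fused-QKV variant is the same identity applied per output slice, which is why the test adds `lora_b[i] @ lora_a[i]` into `weight[:16]` and `weight[16:]` separately.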