Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
2fcd14c
Using active-loras in grid in fused_moe_lora kernel
yugong333 Dec 17, 2025
ddf1597
Capture multiple cuda graph across various active loras
yugong333 Dec 18, 2025
9101205
Clean code
yugong333 Dec 18, 2025
41e3438
Capture multiple cuda graph across various active loras
yugong333 Dec 18, 2025
f0c5d30
Constructing grid with num of active lora
yugong333 Dec 23, 2025
45f86bb
remove the constraint for SD
yugong333 Dec 23, 2025
becb897
Explain specialize_lora_count
yugong333 Dec 30, 2025
9af88f5
fix the error
yugong333 Dec 30, 2025
934291c
Constructing grid with num of active lora
yugong333 Dec 23, 2025
194cb60
Adding argument specialize-active-lora to choose lora kernel grid des…
yugong333 Jan 9, 2026
bbd7f79
clean code
yugong333 Jan 9, 2026
c82a267
Fix bugs
yugong333 Jan 12, 2026
fdd66de
Fix error of not to cover the last lora slot
yugong333 Jan 13, 2026
0f95148
fix the mismatch if effective loras is max_loras + 1
yugong333 Jan 13, 2026
8548e7c
fix the mismatch if effective loras is max_loras + 1
yugong333 Jan 13, 2026
32ba60f
Fix errors
yugong333 Jan 26, 2026
8b71dfe
fix bugs in cuda graph
yugong333 Jan 26, 2026
202903e
Removing files
yugong333 Jan 26, 2026
4d35dda
revert unnecessary changes
yugong333 Jan 27, 2026
4925e33
Fix bugs in punica_gpu.py
yugong333 Jan 27, 2026
a6934e4
Fixing bugs
yugong333 Jan 27, 2026
29eb24c
Remove duplicated code
yugong333 Jan 27, 2026
813ee49
Clean code
yugong333 Jan 27, 2026
18fc597
Cleaning code
yugong333 Jan 28, 2026
585cc6c
Cleaning code
yugong333 Jan 28, 2026
607d31b
clean code
yugong333 Jan 29, 2026
599eaef
Updating test_fused_moe_lora_kernel.py
yugong333 Jan 29, 2026
8a27da6
reset captured_lora_counts
yugong333 Jan 29, 2026
a68972e
Fixing bugs in test files
yugong333 Jan 29, 2026
25ca977
fix bugs in cudagraph_specialize_lora
yugong333 Jan 31, 2026
91773b6
Fix errors
yugong333 Feb 2, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions tests/lora/test_fused_moe_lora_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,10 @@ def use_fused_moe_lora_kernel(
expert_ids = expert_ids.view(max_loras, -1)
sorted_token_ids = sorted_token_ids.view(max_loras, -1)

# num_active_loras is the number of active LoRAs
# (max_loras + 1 to include no-lora case)
num_active_loras = max_loras + 1

fused_moe_lora(
output,
hidden_states,
Expand All @@ -194,6 +198,7 @@ def use_fused_moe_lora_kernel(
max_lora_rank,
top_k_num,
lora_ids,
num_active_loras,
adapter_enabled,
config["BLOCK_SIZE_M"],
config["BLOCK_SIZE_N"],
Expand Down Expand Up @@ -376,6 +381,10 @@ def use_fused_moe_lora_kernel_naive(
adapter_enabled = torch.ones(max_loras + 1, dtype=torch.int32)
lora_ids = torch.arange(max_loras + 2, dtype=torch.int32)

# num_active_loras is the number of active LoRAs
# (max_loras + 1 to include no-lora case)
num_active_loras = max_loras + 1

fused_moe_lora(
output,
hidden_states,
Expand All @@ -389,6 +398,7 @@ def use_fused_moe_lora_kernel_naive(
max_lora_rank,
top_k_num,
lora_ids,
num_active_loras,
adapter_enabled,
config["BLOCK_SIZE_M"],
config["BLOCK_SIZE_N"],
Expand Down
4 changes: 2 additions & 2 deletions tests/lora/test_punica_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ def check_lora_shrink_kernel(
data.inputs_tensor,
data.lora_weights,
out_tensor,
*lora_meta.meta_args(token_nums=token_nums),
*lora_meta.meta_args(token_nums=token_nums, specialize_active_lora=False),
scaling,
)

Expand Down Expand Up @@ -234,7 +234,7 @@ def check_lora_expand_kernel(
data.inputs_tensor,
data.lora_weights,
out_tensor,
*lora_meta.meta_args(token_nums=token_nums),
*lora_meta.meta_args(token_nums=token_nums, specialize_active_lora=False),
offset_start=0,
add_inputs=add_inputs,
)
Expand Down
15 changes: 13 additions & 2 deletions tests/v1/cudagraph/test_cudagraph_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
SchedulerConfig,
VllmConfig,
)
from vllm.config.lora import LoRAConfig
from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.platforms import current_platform
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher
Expand Down Expand Up @@ -47,6 +48,12 @@ def _create_vllm_config(
mock_config.speculative_config = None # No speculative decoding
if not lora_config:
mock_config.lora_config = None
else:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice! Thanks for adding this test case.

# Create a real LoRAConfig with specialize_active_lora enabled
mock_config.lora_config = LoRAConfig(
max_loras=4,
specialize_active_lora=True,
)
# Mimic the behavior of VllmConfig.__post_init__()
if compilation_config.mode == CompilationMode.VLLM_COMPILE:
compilation_config.set_splitting_ops_for_v1(
Expand Down Expand Up @@ -106,15 +113,19 @@ def test_dispatcher(self, cudagraph_mode_str, compilation_mode, lora_config):
)

# Verify the key is initialized correctly
# With LoRA specialization (max_loras=4, specialize_active_lora=True):
# - lora_cases = [0, 1, 2, 4, 5] (no-lora + powers of 2 up to 4 + max_loras+1)
# - capture_sizes = [1, 8]
# - Total keys = 2 sizes × 5 lora_cases = 10
if cudagraph_mode_str in ["FULL_AND_PIECEWISE", "PIECEWISE"]:
assert len(dispatcher.cudagraph_keys[CUDAGraphMode.PIECEWISE]) == (
4 if lora_config else 2
10 if lora_config else 2
)
else:
assert len(dispatcher.cudagraph_keys[CUDAGraphMode.PIECEWISE]) == 0
if cudagraph_mode_str not in ["NONE", "PIECEWISE"]:
assert len(dispatcher.cudagraph_keys[CUDAGraphMode.FULL]) == (
4 if lora_config else 2
10 if lora_config else 2
)
else:
assert len(dispatcher.cudagraph_keys[CUDAGraphMode.FULL]) == 0
Expand Down
7 changes: 7 additions & 0 deletions vllm/config/lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,13 @@ class LoRAConfig:
of multimodal models will be enabled. This is an experimental feature and
currently only supports some MM models such as the Qwen VL series. The default
is False."""
specialize_active_lora: bool = False
"""Whether to construct lora kernel grid by the number of active LoRA adapters.
When set to True, separate cuda graphs will be captured for different counts
of active LoRAs (powers of 2 up to max_loras), which can improve performance
for variable LoRA usage patterns at the cost of increased startup time and
memory usage. Only takes effect when cudagraph_specialize_lora is True.
"""

def compute_hash(self) -> str:
"""
Expand Down
5 changes: 5 additions & 0 deletions vllm/engine/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,6 +485,7 @@ class EngineArgs:
max_cpu_loras: int | None = LoRAConfig.max_cpu_loras
lora_dtype: str | torch.dtype | None = LoRAConfig.lora_dtype
enable_tower_connector_lora: bool = LoRAConfig.enable_tower_connector_lora
specialize_active_lora: bool = LoRAConfig.specialize_active_lora

ray_workers_use_nsight: bool = ParallelConfig.ray_workers_use_nsight
num_gpu_blocks_override: int | None = CacheConfig.num_gpu_blocks_override
Expand Down Expand Up @@ -1026,6 +1027,9 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
"--fully-sharded-loras", **lora_kwargs["fully_sharded_loras"]
)
lora_group.add_argument("--default-mm-loras", **lora_kwargs["default_mm_loras"])
lora_group.add_argument(
"--specialize-active-lora", **lora_kwargs["specialize_active_lora"]
)

# Observability arguments
observability_kwargs = get_kwargs(ObservabilityConfig)
Expand Down Expand Up @@ -1657,6 +1661,7 @@ def create_engine_config(
fully_sharded_loras=self.fully_sharded_loras,
lora_dtype=self.lora_dtype,
enable_tower_connector_lora=self.enable_tower_connector_lora,
specialize_active_lora=self.specialize_active_lora,
max_cpu_loras=self.max_cpu_loras
if self.max_cpu_loras and self.max_cpu_loras > 0
else None,
Expand Down
14 changes: 13 additions & 1 deletion vllm/forward_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,26 @@ class BatchDescriptor(NamedTuple):
"""
Whether this batch has active LoRA adapters.
"""
num_active_loras: int = 0
"""
Number of distinct active LoRA adapters in this batch.
When cudagraph_specialize_lora_count is enabled, separate CUDA graphs
are captured for each num_active_loras value. This allows kernels
(like fused_moe_lora) whose grid size depends on num_active_loras
to be properly captured.
"""

def relax_for_mixed_batch_cudagraphs(self) -> "BatchDescriptor":
"""
Return a relaxed version of current batch descriptor that is still compatible
with PIECEWISE cudagraphs (or mixed prefill-decode FA cudagraphs).
"""
return BatchDescriptor(
self.num_tokens, num_reqs=None, uniform=False, has_lora=self.has_lora
self.num_tokens,
num_reqs=None,
uniform=False,
has_lora=self.has_lora,
num_active_loras=self.num_active_loras,
)


Expand Down
20 changes: 16 additions & 4 deletions vllm/lora/ops/triton_ops/fused_moe_lora_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def _get_ptr(lora_weights: list[torch.Tensor], device: torch.device):


def _adjust_kernel_inputs(
max_loras: int,
num_active_loras: int,
sorted_token_ids: torch.Tensor | None,
expert_ids: torch.Tensor,
):
Expand All @@ -109,7 +109,7 @@ def _adjust_kernel_inputs(
else:
stride_tl = sorted_token_ids.stride(0)
stride_el = expert_ids.stride(0)
grid_lora_dim = max_loras + 1
grid_lora_dim = num_active_loras
return grid_lora_dim, stride_tl, stride_el


Expand Down Expand Up @@ -354,6 +354,7 @@ def _fused_moe_lora_shrink(
num_warps: int,
num_stages: int,
split_k: int,
num_active_loras: int,
mul_routed_weight: bool = False,
use_gdc: bool = False,
) -> None:
Expand All @@ -373,7 +374,7 @@ def _fused_moe_lora_shrink(
b_ptr = _get_ptr(lora_a_stacked, device)

grid_lora_dim, stride_tl, stride_el = _adjust_kernel_inputs(
w1_lora_a_stacked.shape[0], sorted_token_ids, expert_ids
num_active_loras, sorted_token_ids, expert_ids
)
grid = lambda META: (
split_k
Expand Down Expand Up @@ -457,6 +458,7 @@ def _fused_moe_lora_expand(
num_warps: int,
num_stages: int,
split_k: int,
num_active_loras: int,
mul_routed_weight: bool = False,
offset: int = 0,
use_gdc: bool = False,
Expand Down Expand Up @@ -484,7 +486,7 @@ def _fused_moe_lora_expand(
}

grid_lora_dim, stride_tl, stride_el = _adjust_kernel_inputs(
w1_lora_b_stacked.shape[0], sorted_token_ids, expert_ids
num_active_loras, sorted_token_ids, expert_ids
)

grid = lambda META: (
Expand Down Expand Up @@ -557,6 +559,7 @@ def _fused_moe_lora(
max_lora_rank: int,
top_k_num: int,
lora_ids: torch.Tensor,
num_active_loras: int,
adapter_enabled: torch.Tensor,
shrink_block_size_m: int,
shrink_block_size_n: int,
Expand Down Expand Up @@ -648,6 +651,7 @@ def _fused_moe_lora(
shrink_num_warps,
shrink_num_stages,
shrink_split_k,
num_active_loras,
mul_routed_weight,
use_gdc=use_gdc,
)
Expand Down Expand Up @@ -695,6 +699,7 @@ def _fused_moe_lora(
expand_num_warps,
expand_num_stages,
expand_split_k,
num_active_loras,
mul_routed_weight,
offset,
use_gdc=use_gdc,
Expand All @@ -714,6 +719,7 @@ def _fused_moe_lora_fake(
max_lora_rank: int,
top_k_num: int,
lora_ids: torch.Tensor,
num_active_loras: int,
adapter_enabled: torch.Tensor,
shrink_block_size_m: int,
shrink_block_size_n: int,
Expand All @@ -730,6 +736,8 @@ def _fused_moe_lora_fake(
expand_num_stages: int,
expand_split_k: int,
mul_routed_weight: bool = False,
fully_sharded: bool = False,
offset: int = 0,
) -> None:
return

Expand Down Expand Up @@ -761,6 +769,7 @@ def _fused_moe_lora_shrink_fake(
num_warps: int,
num_stages: int,
split_k: int,
num_active_loras: int,
mul_routed_weight: bool = False,
use_gdc: bool = False,
) -> None:
Expand All @@ -770,6 +779,7 @@ def _fused_moe_lora_shrink_fake(
def _fused_moe_lora_expand_fake(
output: torch.Tensor,
a_intermediate_cache1: torch.Tensor,
b_intermediate_cache1: torch.Tensor,
lora_b_stacked: list[torch.Tensor],
topk_weights: torch.Tensor,
sorted_token_ids: torch.Tensor | None,
Expand All @@ -796,7 +806,9 @@ def _fused_moe_lora_expand_fake(
num_warps: int,
num_stages: int,
split_k: int,
num_active_loras: int,
mul_routed_weight: bool = False,
offset: int = 0,
use_gdc: bool = False,
) -> None:
return
Expand Down
7 changes: 3 additions & 4 deletions vllm/lora/ops/triton_ops/lora_expand_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ def _lora_expand(
lora_token_start_loc: torch.Tensor, # shape [max-loras + 2]
lora_ids: torch.Tensor, # shape [max-loras + 1]
no_lora_flag_cpu: torch.Tensor, # shape [1]
num_active_loras: int, # number of active LoRAs (unused here, for API compat)
offset_start: int = 0,
add_inputs: bool = False,
) -> None:
Expand Down Expand Up @@ -234,10 +235,7 @@ def _lora_expand(
grid = (
triton.cdiv(M, BLOCK_M) * triton.cdiv(MAX_N, BLOCK_N),
NUM_SLICES,
# Each LoRA receives its own set of thread blocks for output
# computation. If some LoRA doesn't have any tokens to process, its
# thread blocks simply exit.
MAX_LORAS,
num_active_loras,
)
# We disable PDL temporarily because LoRA kernels are not launching back-to-back,
# making PDL invalid and affecting the kernel performance.
Expand Down Expand Up @@ -291,6 +289,7 @@ def _lora_expand_fake(
lora_token_start_loc: torch.Tensor,
lora_ids: torch.Tensor,
no_lora_flag_cpu: torch.Tensor,
num_active_loras: int,
offset_start: int = 0,
add_inputs: bool = False,
) -> None:
Expand Down
Loading