diff --git a/tests/lora/test_fused_moe_lora_kernel.py b/tests/lora/test_fused_moe_lora_kernel.py index c97421a3f51a..dc3602007dc3 100644 --- a/tests/lora/test_fused_moe_lora_kernel.py +++ b/tests/lora/test_fused_moe_lora_kernel.py @@ -181,6 +181,10 @@ def use_fused_moe_lora_kernel( expert_ids = expert_ids.view(max_loras, -1) sorted_token_ids = sorted_token_ids.view(max_loras, -1) + # num_active_loras is the number of active LoRAs + # (max_loras + 1 to include no-lora case) + num_active_loras = max_loras + 1 + fused_moe_lora( output, hidden_states, @@ -194,6 +198,7 @@ def use_fused_moe_lora_kernel( max_lora_rank, top_k_num, lora_ids, + num_active_loras, adapter_enabled, config["BLOCK_SIZE_M"], config["BLOCK_SIZE_N"], @@ -376,6 +381,10 @@ def use_fused_moe_lora_kernel_naive( adapter_enabled = torch.ones(max_loras + 1, dtype=torch.int32) lora_ids = torch.arange(max_loras + 2, dtype=torch.int32) + # num_active_loras is the number of active LoRAs + # (max_loras + 1 to include no-lora case) + num_active_loras = max_loras + 1 + fused_moe_lora( output, hidden_states, @@ -389,6 +398,7 @@ def use_fused_moe_lora_kernel_naive( max_lora_rank, top_k_num, lora_ids, + num_active_loras, adapter_enabled, config["BLOCK_SIZE_M"], config["BLOCK_SIZE_N"], diff --git a/tests/lora/test_punica_ops.py b/tests/lora/test_punica_ops.py index 5083f500c5cd..963260367671 100644 --- a/tests/lora/test_punica_ops.py +++ b/tests/lora/test_punica_ops.py @@ -161,7 +161,7 @@ def check_lora_shrink_kernel( data.inputs_tensor, data.lora_weights, out_tensor, - *lora_meta.meta_args(token_nums=token_nums), + *lora_meta.meta_args(token_nums=token_nums, specialize_active_lora=False), scaling, ) @@ -234,7 +234,7 @@ def check_lora_expand_kernel( data.inputs_tensor, data.lora_weights, out_tensor, - *lora_meta.meta_args(token_nums=token_nums), + *lora_meta.meta_args(token_nums=token_nums, specialize_active_lora=False), offset_start=0, add_inputs=add_inputs, ) diff --git 
a/tests/v1/cudagraph/test_cudagraph_dispatch.py b/tests/v1/cudagraph/test_cudagraph_dispatch.py index 2f539c9d397d..2b0f8a95d49f 100644 --- a/tests/v1/cudagraph/test_cudagraph_dispatch.py +++ b/tests/v1/cudagraph/test_cudagraph_dispatch.py @@ -17,6 +17,7 @@ SchedulerConfig, VllmConfig, ) +from vllm.config.lora import LoRAConfig from vllm.forward_context import BatchDescriptor, set_forward_context from vllm.platforms import current_platform from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher @@ -47,6 +48,12 @@ def _create_vllm_config( mock_config.speculative_config = None # No speculative decoding if not lora_config: mock_config.lora_config = None + else: + # Create a real LoRAConfig with specialize_active_lora enabled + mock_config.lora_config = LoRAConfig( + max_loras=4, + specialize_active_lora=True, + ) # Mimic the behavior of VllmConfig.__post_init__() if compilation_config.mode == CompilationMode.VLLM_COMPILE: compilation_config.set_splitting_ops_for_v1( @@ -106,15 +113,19 @@ def test_dispatcher(self, cudagraph_mode_str, compilation_mode, lora_config): ) # Verify the key is initialized correctly + # With LoRA specialization (max_loras=4, specialize_active_lora=True): + # - lora_cases = [0, 1, 2, 4, 5] (no-lora + powers of 2 up to 4 + max_loras+1) + # - capture_sizes = [1, 8] + # - Total keys = 2 sizes × 5 lora_cases = 10 if cudagraph_mode_str in ["FULL_AND_PIECEWISE", "PIECEWISE"]: assert len(dispatcher.cudagraph_keys[CUDAGraphMode.PIECEWISE]) == ( - 4 if lora_config else 2 + 10 if lora_config else 2 ) else: assert len(dispatcher.cudagraph_keys[CUDAGraphMode.PIECEWISE]) == 0 if cudagraph_mode_str not in ["NONE", "PIECEWISE"]: assert len(dispatcher.cudagraph_keys[CUDAGraphMode.FULL]) == ( - 4 if lora_config else 2 + 10 if lora_config else 2 ) else: assert len(dispatcher.cudagraph_keys[CUDAGraphMode.FULL]) == 0 diff --git a/vllm/config/lora.py b/vllm/config/lora.py index 56aa08fcb273..f15beffe1df5 100644 --- a/vllm/config/lora.py +++ 
b/vllm/config/lora.py @@ -60,6 +60,13 @@ class LoRAConfig: of multimodal models will be enabled. This is an experimental feature and currently only supports some MM models such as the Qwen VL series. The default is False.""" + specialize_active_lora: bool = False + """Whether to construct lora kernel grid by the number of active LoRA adapters. + When set to True, separate cuda graphs will be captured for different counts + of active LoRAs (powers of 2 up to max_loras), which can improve performance + for variable LoRA usage patterns at the cost of increased startup time and + memory usage. Only takes effect when cudagraph_specialize_lora is True. + """ def compute_hash(self) -> str: """ diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 30eb472ca547..f3e7729f64e3 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -485,6 +485,7 @@ class EngineArgs: max_cpu_loras: int | None = LoRAConfig.max_cpu_loras lora_dtype: str | torch.dtype | None = LoRAConfig.lora_dtype enable_tower_connector_lora: bool = LoRAConfig.enable_tower_connector_lora + specialize_active_lora: bool = LoRAConfig.specialize_active_lora ray_workers_use_nsight: bool = ParallelConfig.ray_workers_use_nsight num_gpu_blocks_override: int | None = CacheConfig.num_gpu_blocks_override @@ -1026,6 +1027,9 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: "--fully-sharded-loras", **lora_kwargs["fully_sharded_loras"] ) lora_group.add_argument("--default-mm-loras", **lora_kwargs["default_mm_loras"]) + lora_group.add_argument( + "--specialize-active-lora", **lora_kwargs["specialize_active_lora"] + ) # Observability arguments observability_kwargs = get_kwargs(ObservabilityConfig) @@ -1657,6 +1661,7 @@ def create_engine_config( fully_sharded_loras=self.fully_sharded_loras, lora_dtype=self.lora_dtype, enable_tower_connector_lora=self.enable_tower_connector_lora, + specialize_active_lora=self.specialize_active_lora, max_cpu_loras=self.max_cpu_loras if 
self.max_cpu_loras and self.max_cpu_loras > 0 else None, diff --git a/vllm/forward_context.py b/vllm/forward_context.py index 20af24c2ca06..9a831a2e20ef 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -47,6 +47,14 @@ class BatchDescriptor(NamedTuple): """ Whether this batch has active LoRA adapters. """ + num_active_loras: int = 0 + """ + Number of distinct active LoRA adapters in this batch. + When specialize_active_lora is enabled, separate CUDA graphs + are captured for each num_active_loras value. This allows kernels + (like fused_moe_lora) whose grid size depends on num_active_loras + to be properly captured. + """ def relax_for_mixed_batch_cudagraphs(self) -> "BatchDescriptor": """ @@ -54,7 +62,11 @@ def relax_for_mixed_batch_cudagraphs(self) -> "BatchDescriptor": with PIECEWISE cudagraphs (or mixed prefill-decode FA cudagraphs). """ return BatchDescriptor( - self.num_tokens, num_reqs=None, uniform=False, has_lora=self.has_lora + self.num_tokens, + num_reqs=None, + uniform=False, + has_lora=self.has_lora, + num_active_loras=self.num_active_loras, ) diff --git a/vllm/lora/ops/triton_ops/fused_moe_lora_op.py b/vllm/lora/ops/triton_ops/fused_moe_lora_op.py index 9e76d742b1e0..3b90b3f9d74e 100644 --- a/vllm/lora/ops/triton_ops/fused_moe_lora_op.py +++ b/vllm/lora/ops/triton_ops/fused_moe_lora_op.py @@ -95,7 +95,7 @@ def _get_ptr(lora_weights: list[torch.Tensor], device: torch.device): def _adjust_kernel_inputs( - max_loras: int, + num_active_loras: int, sorted_token_ids: torch.Tensor | None, expert_ids: torch.Tensor, ): @@ -109,7 +109,7 @@ def _adjust_kernel_inputs( else: stride_tl = sorted_token_ids.stride(0) stride_el = expert_ids.stride(0) - grid_lora_dim = max_loras + 1 + grid_lora_dim = num_active_loras return grid_lora_dim, stride_tl, stride_el @@ -354,6 +354,7 @@ def _fused_moe_lora_shrink( num_warps: int, num_stages: int, split_k: int, + num_active_loras: int, mul_routed_weight: bool = False, use_gdc: bool = False, ) -> 
None: @@ -373,7 +374,7 @@ def _fused_moe_lora_shrink( b_ptr = _get_ptr(lora_a_stacked, device) grid_lora_dim, stride_tl, stride_el = _adjust_kernel_inputs( - w1_lora_a_stacked.shape[0], sorted_token_ids, expert_ids + num_active_loras, sorted_token_ids, expert_ids ) grid = lambda META: ( split_k @@ -457,6 +458,7 @@ def _fused_moe_lora_expand( num_warps: int, num_stages: int, split_k: int, + num_active_loras: int, mul_routed_weight: bool = False, offset: int = 0, use_gdc: bool = False, @@ -484,7 +486,7 @@ def _fused_moe_lora_expand( } grid_lora_dim, stride_tl, stride_el = _adjust_kernel_inputs( - w1_lora_b_stacked.shape[0], sorted_token_ids, expert_ids + num_active_loras, sorted_token_ids, expert_ids ) grid = lambda META: ( @@ -557,6 +559,7 @@ def _fused_moe_lora( max_lora_rank: int, top_k_num: int, lora_ids: torch.Tensor, + num_active_loras: int, adapter_enabled: torch.Tensor, shrink_block_size_m: int, shrink_block_size_n: int, @@ -648,6 +651,7 @@ def _fused_moe_lora( shrink_num_warps, shrink_num_stages, shrink_split_k, + num_active_loras, mul_routed_weight, use_gdc=use_gdc, ) @@ -695,6 +699,7 @@ def _fused_moe_lora( expand_num_warps, expand_num_stages, expand_split_k, + num_active_loras, mul_routed_weight, offset, use_gdc=use_gdc, @@ -714,6 +719,7 @@ def _fused_moe_lora_fake( max_lora_rank: int, top_k_num: int, lora_ids: torch.Tensor, + num_active_loras: int, adapter_enabled: torch.Tensor, shrink_block_size_m: int, shrink_block_size_n: int, @@ -730,6 +736,8 @@ def _fused_moe_lora_fake( expand_num_stages: int, expand_split_k: int, mul_routed_weight: bool = False, + fully_sharded: bool = False, + offset: int = 0, ) -> None: return @@ -761,6 +769,7 @@ def _fused_moe_lora_shrink_fake( num_warps: int, num_stages: int, split_k: int, + num_active_loras: int, mul_routed_weight: bool = False, use_gdc: bool = False, ) -> None: @@ -770,6 +779,7 @@ def _fused_moe_lora_shrink_fake( def _fused_moe_lora_expand_fake( output: torch.Tensor, a_intermediate_cache1: torch.Tensor, + 
b_intermediate_cache1: torch.Tensor, lora_b_stacked: list[torch.Tensor], topk_weights: torch.Tensor, sorted_token_ids: torch.Tensor | None, @@ -796,7 +806,9 @@ def _fused_moe_lora_expand_fake( num_warps: int, num_stages: int, split_k: int, + num_active_loras: int, mul_routed_weight: bool = False, + offset: int = 0, use_gdc: bool = False, ) -> None: return diff --git a/vllm/lora/ops/triton_ops/lora_expand_op.py b/vllm/lora/ops/triton_ops/lora_expand_op.py index 862f5f6b2431..1557d37d2126 100644 --- a/vllm/lora/ops/triton_ops/lora_expand_op.py +++ b/vllm/lora/ops/triton_ops/lora_expand_op.py @@ -138,6 +138,7 @@ def _lora_expand( lora_token_start_loc: torch.Tensor, # shape [max-loras + 2] lora_ids: torch.Tensor, # shape [max-loras + 1] no_lora_flag_cpu: torch.Tensor, # shape [1] + num_active_loras: int, # number of active LoRAs (unused here, for API compat) offset_start: int = 0, add_inputs: bool = False, ) -> None: @@ -234,10 +235,7 @@ def _lora_expand( grid = ( triton.cdiv(M, BLOCK_M) * triton.cdiv(MAX_N, BLOCK_N), NUM_SLICES, - # Each LoRA receives its own set of thread blocks for output - # computation. If some LoRA doesn't have any tokens to process, its - # thread blocks simply exit. - MAX_LORAS, + num_active_loras, ) # We disable PDL temporarily because LoRA kernels are not launching back-to-back, # making PDL invalid and affecting the kernel performance. @@ -291,6 +289,7 @@ def _lora_expand_fake( lora_token_start_loc: torch.Tensor, lora_ids: torch.Tensor, no_lora_flag_cpu: torch.Tensor, + num_active_loras: int, offset_start: int = 0, add_inputs: bool = False, ) -> None: diff --git a/vllm/lora/ops/triton_ops/lora_kernel_metadata.py b/vllm/lora/ops/triton_ops/lora_kernel_metadata.py index c3bef7680dd0..1fec1d50c1a1 100644 --- a/vllm/lora/ops/triton_ops/lora_kernel_metadata.py +++ b/vllm/lora/ops/triton_ops/lora_kernel_metadata.py @@ -4,7 +4,8 @@ LoRA kernels metadata preparation utilities. 
""" -from dataclasses import dataclass +import bisect +from dataclasses import dataclass, field import torch @@ -28,9 +29,22 @@ class LoRAKernelMeta: # to early exit from inside the lora_expand / lora_shrink torch operation. no_lora_flag_cpu: torch.Tensor + # Number of active LoRAs (unique non-(-1) values in token_lora_mapping) + # Stored as a Python int to avoid GPU->CPU sync during forward pass + num_active_loras: int = 0 + + # Captured LoRA counts for cudagraph specialization (sorted list). + # When specialize_active_lora is enabled, num_active_loras is rounded up + # to the nearest value in this list to match cudagraph capture keys. + # Empty list means no specialization (use actual count). + captured_lora_counts: list[int] = field(default_factory=list) + @staticmethod def make( - max_loras: int, max_num_tokens: int, device: torch.device | str + max_loras: int, + max_num_tokens: int, + device: torch.device | str, + captured_lora_counts: list[int] | None = None, ) -> "LoRAKernelMeta": token_lora_mapping = torch.empty( max_num_tokens, dtype=torch.int32, device=device @@ -66,6 +80,9 @@ def make( num_tokens_per_lora=num_tokens_per_lora, lora_token_start_loc=lora_token_start_loc, no_lora_flag_cpu=no_lora_flag_cpu, + captured_lora_counts=sorted(captured_lora_counts) + if captured_lora_counts + else [], ) def _reset(self): @@ -73,6 +90,8 @@ def _reset(self): self.num_tokens_per_lora.fill_(0) self.lora_token_start_loc.fill_(0) self.no_lora_flag_cpu.fill_(False) + self.num_active_loras = 0 + self.captured_lora_counts = [] def prepare_tensors(self, token_lora_mapping: torch.Tensor) -> None: """ @@ -118,6 +137,15 @@ def prepare_tensors(self, token_lora_mapping: torch.Tensor) -> None: num_tokens_per_lora, non_blocking=True ) + self.num_active_loras = lora_ids.size(0) + + # Round up num_active_loras to match cudagraph capture keys. + # This ensures the kernel grid dimension matches the captured graph. 
+ if self.captured_lora_counts and self.num_active_loras > 0: + idx = bisect.bisect_left(self.captured_lora_counts, self.num_active_loras) + if idx < len(self.captured_lora_counts): + self.num_active_loras = self.captured_lora_counts[idx] + # lora_token_start_loc lora_token_start_loc = torch.cumsum(num_tokens_per_lora, dim=0) self.lora_token_start_loc[1 : 1 + lora_token_start_loc.size(0)].copy_( @@ -125,7 +153,9 @@ def prepare_tensors(self, token_lora_mapping: torch.Tensor) -> None: ) def meta_args( - self, token_nums: int + self, + token_nums: int, + specialize_active_lora: bool, ) -> tuple[ torch.Tensor, torch.Tensor, @@ -133,6 +163,7 @@ def meta_args( torch.Tensor, torch.Tensor, torch.Tensor, + int, ]: """ This function returns the kernel metadata required for the current @@ -144,6 +175,7 @@ def meta_args( token_nums (int): Number of input tokens in the current forward pass of the kernel. """ + max_loras = self.active_lora_ids.size(0) - 1 return ( self.token_lora_mapping[:token_nums], self.token_indices_sorted_by_lora_ids[:token_nums], @@ -151,4 +183,5 @@ def meta_args( self.lora_token_start_loc, self.active_lora_ids, self.no_lora_flag_cpu, + self.num_active_loras if specialize_active_lora else max_loras + 1, ) diff --git a/vllm/lora/ops/triton_ops/lora_shrink_op.py b/vllm/lora/ops/triton_ops/lora_shrink_op.py index 9ba82b396a48..8dbd988f7685 100644 --- a/vllm/lora/ops/triton_ops/lora_shrink_op.py +++ b/vllm/lora/ops/triton_ops/lora_shrink_op.py @@ -134,6 +134,7 @@ def _lora_shrink( lora_token_start_loc: torch.Tensor, # shape [max-loras + 2] lora_ids: torch.Tensor, # shape [max-loras + 1] no_lora_flag_cpu: torch.Tensor, # shape [1] + num_active_loras: int, # number of active LoRAs (unused here, for API compat) scaling: float, ) -> None: """ @@ -214,10 +215,7 @@ def _lora_shrink( grid = ( SPLIT_K * triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N), NUM_SLICES, - # Each LoRA receives its own set of thread blocks for output - # computation. 
If some LoRA doesn't have any tokens to process, its - # thread blocks exit early. - MAX_LORAS, + num_active_loras, ) # We disable PDL temporarily because LoRA kernels are not launching back-to-back, # making PDL invalid and affecting the kernel performance. @@ -269,6 +267,7 @@ def _lora_shrink_fake( lora_token_start_loc: torch.Tensor, lora_ids: torch.Tensor, no_lora_flag_cpu: torch.Tensor, + num_active_loras: int, scaling: float, ) -> None: return diff --git a/vllm/lora/punica_wrapper/punica_gpu.py b/vllm/lora/punica_wrapper/punica_gpu.py index b704a74c7568..b75d297ba5c4 100644 --- a/vllm/lora/punica_wrapper/punica_gpu.py +++ b/vllm/lora/punica_wrapper/punica_gpu.py @@ -12,6 +12,7 @@ import torch from vllm.lora.layers import LoRAMapping +from vllm.lora.utils import get_captured_lora_counts from vllm.triton_utils import HAS_TRITON, triton from vllm.utils.math_utils import round_up @@ -48,8 +49,16 @@ def __init__( self.lora_config = kwargs["lora_config"] self.max_loras = self.lora_config.max_loras + # Compute captured LoRA counts for cudagraph specialization. + captured_lora_counts = get_captured_lora_counts( + self.max_loras, self.lora_config.specialize_active_lora + ) + self.token_mapping_meta = LoRAKernelMeta.make( - self.max_loras, max_num_batched_tokens, device=device + self.max_loras, + max_num_batched_tokens, + device=device, + captured_lora_counts=captured_lora_counts, ) # When speculative decoding is enabled, max_num_samples is @@ -57,7 +66,10 @@ def __init__( # This line can be optimized by replacing max_num_batched_tokens # to max_batches * (num_speculative_decoding_tokens + 1). 
self.prompt_mapping_meta = LoRAKernelMeta.make( - self.max_loras, max_num_batched_tokens, device=device + self.max_loras, + max_num_batched_tokens, + device=device, + captured_lora_counts=captured_lora_counts, ) def update_metadata( @@ -102,7 +114,9 @@ def add_shrink( x, lora_a_stacked, y, - *self.token_mapping_meta.meta_args(x.size(0)), + *self.token_mapping_meta.meta_args( + x.size(0), self.lora_config.specialize_active_lora + ), scale, ) @@ -143,7 +157,9 @@ def add_expand( x, lora_b_stacked, y, - *self.token_mapping_meta.meta_args(num_tokens), + *self.token_mapping_meta.meta_args( + num_tokens, self.lora_config.specialize_active_lora + ), offset_start=offset_start, add_inputs=True, ) @@ -175,7 +191,9 @@ def add_lora_embedding( x.unsqueeze(dim=0), (lora_b_stacked,), y, - *self.token_mapping_meta.meta_args(x.size(0)), + *self.token_mapping_meta.meta_args( + x.size(0), self.lora_config.specialize_active_lora + ), offset_start=0, add_inputs=add_inputs, ) @@ -287,7 +305,9 @@ def add_lora_logits( x, [lora_a_stacked], buffer.unsqueeze(dim=0), - *self.prompt_mapping_meta.meta_args(x.size(0)), + *self.prompt_mapping_meta.meta_args( + x.size(0), self.lora_config.specialize_active_lora + ), scale, ) @@ -295,7 +315,9 @@ def add_lora_logits( buffer.unsqueeze(dim=0), [lora_b_stacked], y, - *self.prompt_mapping_meta.meta_args(buffer.size(0)), + *self.prompt_mapping_meta.meta_args( + buffer.size(0), self.lora_config.specialize_active_lora + ), add_inputs=True, ) y = y.view_as(y_org) @@ -316,8 +338,10 @@ def moe_lora_align_block_size( Aligns tokens and experts into block-sized chunks for LoRA-based mixture-of-experts (MoE) execution. 
""" - (token_lora_mapping, _, _, _, lora_ids, _) = self.token_mapping_meta.meta_args( - num_tokens + (token_lora_mapping, _, _, _, lora_ids, _, _) = ( + self.token_mapping_meta.meta_args( + num_tokens, self.lora_config.specialize_active_lora + ) ) if naive_block_assignment: expert_ids = topk_ids.reshape(-1) @@ -392,7 +416,10 @@ def add_lora_fused_moe( _, lora_ids, _, - ) = self.token_mapping_meta.meta_args(x.size(0)) + num_active_loras, + ) = self.token_mapping_meta.meta_args( + x.size(0), self.lora_config.specialize_active_lora + ) if token_lora_mapping is None: token_lora_mapping = token_lora_mapping_meta fused_moe_lora( @@ -408,6 +435,7 @@ def add_lora_fused_moe( max_lora_rank, top_k_num, lora_ids, + num_active_loras, adapter_enabled, shrink_config.get("BLOCK_SIZE_M", 64), shrink_config.get("BLOCK_SIZE_N", 64), diff --git a/vllm/lora/utils.py b/vllm/lora/utils.py index 2840d5eda746..9b23d7e0c8b5 100644 --- a/vllm/lora/utils.py +++ b/vllm/lora/utils.py @@ -44,6 +44,25 @@ logger = init_logger(__name__) + +def get_captured_lora_counts(max_loras: int, specialize: bool) -> list[int]: + """ + Returns num_active_loras values for cudagraph capture. + + When specialize=True: powers of 2 up to max_loras, plus max_loras + 1. + When specialize=False: just [max_loras + 1]. + + This is the single source of truth for LoRA capture cases, used by both + CudagraphDispatcher and PunicaWrapperGPU. 
+ """ + if not specialize: + return [max_loras + 1] + + return [ + n for n in range(1, max_loras + 2) if (n & (n - 1)) == 0 or n == max_loras + 1 + ] + + _GLOBAL_LORA_ID = 0 diff --git a/vllm/v1/cudagraph_dispatcher.py b/vllm/v1/cudagraph_dispatcher.py index f5738c6b3ca0..6f3e029c793b 100644 --- a/vllm/v1/cudagraph_dispatcher.py +++ b/vllm/v1/cudagraph_dispatcher.py @@ -5,6 +5,7 @@ from vllm.config import CUDAGraphMode, VllmConfig from vllm.forward_context import BatchDescriptor from vllm.logger import init_logger +from vllm.lora.utils import get_captured_lora_counts logger = init_logger(__name__) @@ -57,6 +58,11 @@ def __init__(self, vllm_config: VllmConfig): ) self.keys_initialized = False + self.specialize_lora_count = ( + self.vllm_config.lora_config.specialize_active_lora + if self.vllm_config.lora_config is not None + else False + ) # Default cudagraph_mode to NONE until initialize_cudagraph_keys is called self.cudagraph_mode = CUDAGraphMode.NONE @@ -92,8 +98,33 @@ def _compute_bs_to_padded_graph_size(self) -> None: "Use values from cudagraph_capture_sizes." ) + def _get_lora_cases(self) -> list[int]: + """ + Returns list of has_lora values for CUDA graph capture. + This is the single source of truth for LoRA capture cases. 
+ """ + lora_config = self.vllm_config.lora_config + if lora_config is None: + # No LoRA configured - single case with no LoRA + return [0] + + # LoRA is enabled - capture graphs based on cudagraph_specialize_lora + if self.compilation_config.cudagraph_specialize_lora: + captured_counts = get_captured_lora_counts( + lora_config.max_loras, self.specialize_lora_count + ) + # Specialize: capture separate graphs for with and without LoRA + return [0] + captured_counts + else: + # No specialization: only capture graphs with LoRA active + return [lora_config.max_loras + 1] + def _create_padded_batch_descriptor( - self, num_tokens: int, uniform_decode: bool, has_lora: bool + self, + num_tokens: int, + uniform_decode: bool, + has_lora: bool, + num_active_loras: int = 0, ) -> BatchDescriptor: max_num_seqs = self.vllm_config.scheduler_config.max_num_seqs uniform_decode_query_len = self.uniform_decode_query_len @@ -111,6 +142,7 @@ def _create_padded_batch_descriptor( num_reqs=num_reqs, uniform=uniform_decode, has_lora=has_lora, + num_active_loras=num_active_loras, ) def add_cudagraph_key( @@ -135,26 +167,23 @@ def initialize_cudagraph_keys( self._compute_bs_to_padded_graph_size() - # LoRA activation cases to specialize the cuda graphs on - if self.vllm_config.lora_config: - if self.compilation_config.cudagraph_specialize_lora: - lora_cases = [True, False] - else: - lora_cases = [True] - else: - lora_cases = [False] + # Get LoRA cases to capture + lora_cases = self._get_lora_cases() + self.captured_lora_counts = [ + lora_count for lora_count in lora_cases if lora_count + ] # Note: we create all valid keys for cudagraph here but do not # guarantee all keys would be used. For example, if we allow lazy # capturing in future PR, some keys may never be triggered. 
if cudagraph_mode.mixed_mode() != CUDAGraphMode.NONE: - for bs, has_lora in product( + for bs, num_active_loras in product( self.compilation_config.cudagraph_capture_sizes, lora_cases ): self.add_cudagraph_key( cudagraph_mode.mixed_mode(), self._create_padded_batch_descriptor( - bs, False, has_lora + bs, False, num_active_loras > 0, num_active_loras ).relax_for_mixed_batch_cudagraphs(), ) @@ -173,10 +202,14 @@ def initialize_cudagraph_keys( for x in self.compilation_config.cudagraph_capture_sizes if x <= max_num_tokens and x >= uniform_decode_query_len ] - for bs, has_lora in product(cudagraph_capture_sizes_for_decode, lora_cases): + for bs, num_active_loras in product( + cudagraph_capture_sizes_for_decode, lora_cases + ): self.add_cudagraph_key( CUDAGraphMode.FULL, - self._create_padded_batch_descriptor(bs, True, has_lora), + self._create_padded_batch_descriptor( + bs, True, num_active_loras > 0, num_active_loras + ), ) self.keys_initialized = True @@ -187,6 +220,7 @@ def dispatch( uniform_decode: bool = False, has_lora: bool = False, disable_full: bool = False, + num_active_loras: int = 0, ) -> tuple[CUDAGraphMode, BatchDescriptor]: """ Given conditions(e.g.,batch descriptor and if using piecewise only), @@ -202,6 +236,7 @@ def dispatch( disable_full: If True, skip FULL cudagraph checks and return PIECEWISE or NONE only. (can be used for features like cascade attention that are not supported by full cudagraphs) + num_active_loras: Number of distinct active LoRA adapters. """ if ( not self.keys_initialized @@ -210,8 +245,24 @@ def dispatch( ): return CUDAGraphMode.NONE, BatchDescriptor(num_tokens) + effective_num_active_loras = num_active_loras + if has_lora and num_active_loras > 0: + if self.specialize_lora_count: + # Find the smallest captured `num_active_loras` that is >= the current + # `num_active_loras`. This is because we only capture graphs for + # a subset of possible `num_active_loras` values (powers of 2). 
+ import bisect + + idx = bisect.bisect_left(self.captured_lora_counts, num_active_loras) + if idx < len(self.captured_lora_counts): + effective_num_active_loras = self.captured_lora_counts[idx] + else: + # When not specializing, graphs are captured only with max_loras + 1, + # so we must use max_loras + 1 for dispatch to find a matching graph. + effective_num_active_loras = self.vllm_config.lora_config.max_loras + 1 + batch_desc = self._create_padded_batch_descriptor( - num_tokens, uniform_decode, has_lora + num_tokens, uniform_decode, has_lora, effective_num_active_loras ) relaxed_batch_desc = batch_desc.relax_for_mixed_batch_cudagraphs() diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 061ac8680157..862f571bd8eb 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -3082,6 +3082,7 @@ def _determine_batch_execution_and_padding( # be improved in model runner v2) force_uniform_decode: bool | None = None, force_has_lora: bool | None = None, + force_num_active_loras: int | None = None, num_encoder_reqs: int = 0, ) -> tuple[ CUDAGraphMode, @@ -3103,11 +3104,13 @@ def _determine_batch_execution_and_padding( self.model_config.is_encoder_decoder and num_encoder_reqs > 0 ) - has_lora = ( - len(self.input_batch.lora_id_to_lora_request) > 0 - if force_has_lora is None - else force_has_lora + # Compute LoRA state for cudagraph dispatch + num_active_loras = ( + force_num_active_loras + if force_num_active_loras is not None + else len(self.input_batch.lora_id_to_lora_request) ) + has_lora = num_active_loras > 0 if force_has_lora is None else force_has_lora num_tokens_padded = self._pad_for_sequence_parallelism(num_tokens) dispatch_cudagraph = ( @@ -3116,6 +3119,7 @@ def _determine_batch_execution_and_padding( has_lora=has_lora, uniform_decode=uniform_decode, disable_full=disable_full, + num_active_loras=num_active_loras, ) if not force_eager else (CUDAGraphMode.NONE, 
BatchDescriptor(num_tokens_padded)) @@ -4606,8 +4610,8 @@ def _dummy_run( is_profile: bool = False, create_mixed_batch: bool = False, remove_lora: bool = True, - activate_lora: bool = False, is_graph_capturing: bool = False, + num_active_loras: int = 0, ) -> tuple[torch.Tensor, torch.Tensor]: """ Run a dummy forward pass to warm up/profile run or capture the @@ -4630,7 +4634,8 @@ def _dummy_run( create_mixed_batch: If True, create a mixed batch with both decode (1 token) and prefill (multiple tokens) requests. remove_lora: If False, dummy LoRAs are not destroyed after the run - activate_lora: If False, dummy_run is performed without LoRAs. + num_active_loras: Number of distinct active LoRAs to capture for. + LoRA is activated when num_active_loras > 0. """ mm_config = self.vllm_config.model_config.multimodal_config if mm_config and mm_config.mm_encoder_only: @@ -4712,7 +4717,10 @@ def _dummy_run( # `force_has_lora` is used for cudagraph capture; because LoRA is # activated later in the context manager, but we need to know the # LoRA state when determining the batch descriptor for capture - force_has_lora=activate_lora, + force_has_lora=num_active_loras > 0, + # `force_num_active_loras` is used for cudagraph capture; because we + # need to capture graphs for specific num_active_loras counts + force_num_active_loras=num_active_loras, ) ) @@ -4782,8 +4790,8 @@ def _dummy_run( self.lora_config, num_scheduled_tokens, num_sampled_tokens, - activate_lora, remove_lora, + num_active_loras, ): # Make sure padding doesn't exceed max_num_tokens assert num_tokens_padded <= self.max_num_tokens @@ -4884,7 +4892,10 @@ def _dummy_run( # lora cases when cudagraph_specialize_lora is enabled. 
This is a # short term mitigation for issue mentioned in # https://github.com/vllm-project/vllm/issues/28334 - if self.compilation_config.cudagraph_specialize_lora and activate_lora: + if ( + self.compilation_config.cudagraph_specialize_lora + and num_active_loras > 0 + ): use_cudagraphs = False self.drafter.dummy_run( @@ -5259,7 +5270,7 @@ def _capture_cudagraphs( # We skip EPLB here since we don't want to record dummy metrics for batch_desc in batch_descriptors: num_tokens = batch_desc.num_tokens - activate_lora = batch_desc.has_lora + num_active_loras = batch_desc.num_active_loras # We currently only capture ubatched graphs when its a FULL # cudagraph, a uniform decode batch, and the number of tokens @@ -5286,7 +5297,7 @@ def _capture_cudagraphs( num_tokens, cudagraph_runtime_mode=CUDAGraphMode.NONE, allow_microbatching=allow_microbatching, - activate_lora=activate_lora, + num_active_loras=num_active_loras, ) # Capture run @@ -5294,7 +5305,7 @@ def _capture_cudagraphs( num_tokens, cudagraph_runtime_mode=cudagraph_runtime_mode, allow_microbatching=allow_microbatching, - activate_lora=activate_lora, + num_active_loras=num_active_loras, is_graph_capturing=True, ) self.maybe_remove_all_loras(self.lora_config) diff --git a/vllm/v1/worker/lora_model_runner_mixin.py b/vllm/v1/worker/lora_model_runner_mixin.py index b7d488ea1c18..53873d156f88 100644 --- a/vllm/v1/worker/lora_model_runner_mixin.py +++ b/vllm/v1/worker/lora_model_runner_mixin.py @@ -133,11 +133,23 @@ def maybe_select_dummy_loras( num_scheduled_tokens: np.ndarray, mapping_type: LoRAMappingType = LoRAMappingType.LANGUAGE, num_sampled_tokens: np.ndarray | None = None, - activate_lora: bool = True, + num_active_loras: int = 0, ): + """ + Context manager to select dummy LoRAs for capture/warmup. + + Args: + lora_config: LoRA configuration, or None if LoRA is disabled. + num_scheduled_tokens: Array of scheduled token counts per request. + num_sampled_tokens: Array of sampled token counts per request. 
+ num_active_loras: Number of distinct active LoRAs to use. + - 0: No LoRA active (set up zero mappings). + - >0: Use exactly this many distinct LoRAs. + """ if num_sampled_tokens is None: num_sampled_tokens = np.ones_like(num_scheduled_tokens, dtype=np.int32) + # Skip LoRA setup entirely only if no LoRA config if lora_config is None: yield else: @@ -145,15 +157,52 @@ def maybe_select_dummy_loras( assert self.lora_manager is not None, "LoRA is not enabled" num_reqs = len(num_scheduled_tokens) - num_loras = lora_config.max_loras + max_loras = lora_config.max_loras + + # Determine how many distinct LoRAs to use and whether to include + # no-LoRA tokens (-1 entries). + # When num_active_loras > max_loras (e.g., max_loras + 1), we need + # to include -1 entries to simulate batches with both LoRA and + # no-LoRA tokens. This ensures prepare_tensors computes the correct + # num_active_loras that matches the cudagraph capture key. + if num_active_loras == 0: + # No LoRA active - use 0 mappings like the original code + effective_num_loras = 0 + include_no_lora = False + elif num_active_loras > max_loras: + # num_active_loras > max_loras means we want max_loras adapters + # PLUS no-LoRA tokens (-1). This is the max_loras + 1 case. + effective_num_loras = max_loras + include_no_lora = True + else: + # Specific number of active LoRAs requested + effective_num_loras = min(num_active_loras, max_loras) + include_no_lora = False # Make prompt lora mapping # Assign LoRA IDs cyclically to simulate a worst-case scenario. - if activate_lora: - prompt_lora_mapping = ( - np.arange(num_reqs, dtype=np.int32) % num_loras - ) + 1 + # LoRA IDs are 1-indexed (1 to max_loras) as required by LoRARequest. + # convert_mapping() will convert these to 0-indexed slot indices. 
+ if effective_num_loras > 0: + if include_no_lora: + # Include -1 (no-LoRA) entries by cycling through + # -1, 1, 2, ..., effective_num_loras + # This ensures prepare_tensors sees both LoRA and no-LoRA + # tokens, computing num_active_loras = effective_num_loras+1 + cycle_values = np.array( + list(range(1, effective_num_loras + 1)), + dtype=np.int32, + ) + prompt_lora_mapping = cycle_values[ + np.arange(num_reqs, dtype=np.int32) % len(cycle_values) + ] + else: + # Use 1 to effective_num_loras (1-indexed lora IDs) + prompt_lora_mapping = ( + np.arange(num_reqs, dtype=np.int32) % effective_num_loras + ) + 1 else: + # No LoRA active - use 0 for all tokens (original behavior) prompt_lora_mapping = np.zeros(num_reqs, dtype=np.int32) # Make sample lora mapping @@ -162,14 +211,14 @@ def maybe_select_dummy_loras( # Make token lora mapping token_lora_mapping = np.repeat(prompt_lora_mapping, num_scheduled_tokens) - # Make dummy lora requests + # Make dummy lora requests (only for the active LoRAs) lora_requests: set[LoRARequest] = { LoRARequest( lora_name=f"warmup_{lora_id}", lora_int_id=lora_id, lora_path="/not/a/real/path", ) - for lora_id in range(1, num_loras + 1) + for lora_id in range(1, effective_num_loras + 1) } self._set_active_loras( @@ -187,10 +236,21 @@ def maybe_dummy_run_with_lora( lora_config: LoRAConfig | None, num_scheduled_tokens: np.ndarray, num_sampled_tokens: np.ndarray, - activate_lora: bool = True, remove_lora: bool = True, + num_active_loras: int = 0, mapping_type: LoRAMappingType = LoRAMappingType.LANGUAGE, ): + """ + Context manager for dummy runs with LoRA. + + Args: + lora_config: LoRA configuration. + num_scheduled_tokens: Array of scheduled token counts per request. + num_sampled_tokens: Array of sampled token counts per request. + remove_lora: Whether to remove LoRAs after the context exits. + num_active_loras: Number of distinct active LoRAs to use. + LoRA is activated when num_active_loras > 0. 
+ """ with ( self.maybe_setup_dummy_loras(lora_config, remove_lora), self.maybe_select_dummy_loras( @@ -198,7 +258,7 @@ def maybe_dummy_run_with_lora( num_scheduled_tokens, mapping_type, num_sampled_tokens, - activate_lora, + num_active_loras, ), ): yield