Skip to content
6 changes: 4 additions & 2 deletions tests/lora/test_fused_moe_lora_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,8 @@ def use_fused_moe_lora_kernel(

# num_active_loras is the number of active LoRAs
# (max_loras + 1 to include no-lora case)
num_active_loras = max_loras + 1
# Stored as CPU tensor to match the kernel API (torch.compile compatibility)
num_active_loras = torch.tensor([max_loras + 1], dtype=torch.int32, device="cpu")

fused_moe_lora(
output,
Expand Down Expand Up @@ -389,7 +390,8 @@ def use_fused_moe_lora_kernel_naive(

# num_active_loras is the number of active LoRAs
# (max_loras + 1 to include no-lora case)
num_active_loras = max_loras + 1
# Stored as CPU tensor to match the kernel API (torch.compile compatibility)
num_active_loras = torch.tensor([max_loras + 1], dtype=torch.int32, device="cpu")

fused_moe_lora(
output,
Expand Down
7 changes: 6 additions & 1 deletion tests/lora/test_gptoss_tp.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,12 @@ def generate_and_test(llm: vllm.LLM, lora_path: str, lora_id: int) -> None:


@pytest.mark.parametrize("mxfp4_use_marlin", [True, False])
@pytest.mark.parametrize("specialize_active_lora", [True, False])
def test_gpt_oss_lora(
monkeypatch: pytest.MonkeyPatch, gptoss20b_lora_files, mxfp4_use_marlin
monkeypatch: pytest.MonkeyPatch,
gptoss20b_lora_files,
mxfp4_use_marlin,
specialize_active_lora,
):
with monkeypatch.context() as m:
m.setenv("VLLM_MXFP4_USE_MARLIN", "1" if mxfp4_use_marlin else "0")
Expand All @@ -83,6 +87,7 @@ def test_gpt_oss_lora(
max_lora_rank=8,
max_num_seqs=2,
max_num_batched_tokens=2048,
specialize_active_lora=specialize_active_lora,
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the OOM is due to this.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes @RunkaiTao, we should add back the `cudagraph_specialize_lora=False` argument. Thanks!

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I made it false now

compilation_config=vllm.config.CompilationConfig( # Avoid OOM
cudagraph_specialize_lora=False,
),
Expand Down
16 changes: 8 additions & 8 deletions vllm/lora/ops/triton_ops/fused_moe_lora_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def _get_ptr(lora_weights: list[torch.Tensor], device: torch.device):


def _adjust_kernel_inputs(
num_active_loras: int,
num_active_loras: torch.Tensor, # CPU tensor [1], number of active LoRAs
sorted_token_ids: torch.Tensor | None,
expert_ids: torch.Tensor,
):
Expand All @@ -109,7 +109,7 @@ def _adjust_kernel_inputs(
else:
stride_tl = sorted_token_ids.stride(0)
stride_el = expert_ids.stride(0)
grid_lora_dim = num_active_loras
grid_lora_dim = num_active_loras.item()
return grid_lora_dim, stride_tl, stride_el


Expand Down Expand Up @@ -354,7 +354,7 @@ def _fused_moe_lora_shrink(
num_warps: int,
num_stages: int,
split_k: int,
num_active_loras: int,
num_active_loras: torch.Tensor, # CPU tensor [1], number of active LoRAs
mul_routed_weight: bool = False,
use_gdc: bool = False,
) -> None:
Expand Down Expand Up @@ -458,7 +458,7 @@ def _fused_moe_lora_expand(
num_warps: int,
num_stages: int,
split_k: int,
num_active_loras: int,
num_active_loras: torch.Tensor, # CPU tensor [1], number of active LoRAs
mul_routed_weight: bool = False,
offset: int = 0,
use_gdc: bool = False,
Expand Down Expand Up @@ -559,7 +559,7 @@ def _fused_moe_lora(
max_lora_rank: int,
top_k_num: int,
lora_ids: torch.Tensor,
num_active_loras: int,
num_active_loras: torch.Tensor, # CPU tensor [1], number of active LoRAs
adapter_enabled: torch.Tensor,
shrink_block_size_m: int,
shrink_block_size_n: int,
Expand Down Expand Up @@ -719,7 +719,7 @@ def _fused_moe_lora_fake(
max_lora_rank: int,
top_k_num: int,
lora_ids: torch.Tensor,
num_active_loras: int,
num_active_loras: torch.Tensor, # CPU tensor [1], number of active LoRAs
adapter_enabled: torch.Tensor,
shrink_block_size_m: int,
shrink_block_size_n: int,
Expand Down Expand Up @@ -769,7 +769,7 @@ def _fused_moe_lora_shrink_fake(
num_warps: int,
num_stages: int,
split_k: int,
num_active_loras: int,
num_active_loras: torch.Tensor, # CPU tensor [1], number of active LoRAs
mul_routed_weight: bool = False,
use_gdc: bool = False,
) -> None:
Expand Down Expand Up @@ -805,7 +805,7 @@ def _fused_moe_lora_expand_fake(
num_warps: int,
num_stages: int,
split_k: int,
num_active_loras: int,
num_active_loras: torch.Tensor, # CPU tensor [1], number of active LoRAs
mul_routed_weight: bool = False,
offset: int = 0,
use_gdc: bool = False,
Expand Down
6 changes: 3 additions & 3 deletions vllm/lora/ops/triton_ops/lora_expand_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def _lora_expand(
lora_token_start_loc: torch.Tensor, # shape [max-loras + 2]
lora_ids: torch.Tensor, # shape [max-loras + 1]
no_lora_flag_cpu: torch.Tensor, # shape [1]
num_active_loras: int, # number of active LoRAs (unused here, for API compat)
num_active_loras: torch.Tensor, # CPU tensor [1], number of active LoRAs
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: can you please add the `# CPU tensor [1], number of active LoRAs` comment to the other places where `num_active_loras: torch.Tensor` is an argument?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comments are added.

offset_start: int = 0,
add_inputs: bool = False,
) -> None:
Expand Down Expand Up @@ -235,7 +235,7 @@ def _lora_expand(
grid = (
triton.cdiv(M, BLOCK_M) * triton.cdiv(MAX_N, BLOCK_N),
NUM_SLICES,
num_active_loras,
num_active_loras.item(),
)
# We disable PDL temporarily because LoRA kernels are not launching back-to-back,
# making PDL invalid and affecting the kernel performance.
Expand Down Expand Up @@ -289,7 +289,7 @@ def _lora_expand_fake(
lora_token_start_loc: torch.Tensor,
lora_ids: torch.Tensor,
no_lora_flag_cpu: torch.Tensor,
num_active_loras: int,
num_active_loras: torch.Tensor, # CPU tensor [1], number of active LoRAs
offset_start: int = 0,
add_inputs: bool = False,
) -> None:
Expand Down
42 changes: 30 additions & 12 deletions vllm/lora/ops/triton_ops/lora_kernel_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,16 @@ class LoRAKernelMeta:
# to early exit from inside the lora_expand / lora_shrink torch operation.
no_lora_flag_cpu: torch.Tensor

# Number of active LoRAs (unique non-(-1) values in token_lora_mapping)
# Stored as a Python int to avoid GPU->CPU sync during forward pass
num_active_loras: int = 0
# Number of active LoRAs (unique non-(-1) values in token_lora_mapping).
# Stored as a CPU tensor (not a Python int) so that torch.compile treats
# it as a dynamic value rather than baking it as a constant at trace time.
# This follows the same pattern as no_lora_flag_cpu above.
num_active_loras_cpu: torch.Tensor

# Default num_active_loras value (max_loras + 1) as a CPU tensor,
# used when specialize_active_lora is False to avoid allocating a
# new tensor on every meta_args() call.
default_num_active_loras_cpu: torch.Tensor

# Captured LoRA counts for cudagraph specialization (sorted list).
# When specialize_active_lora is enabled, num_active_loras is rounded up
Expand Down Expand Up @@ -73,13 +80,20 @@ def make(

no_lora_flag_cpu = torch.tensor([False], dtype=torch.bool, device="cpu")

num_active_loras_cpu = torch.tensor([0], dtype=torch.int32, device="cpu")
default_num_active_loras_cpu = torch.tensor(
[max_loras + 1], dtype=torch.int32, device="cpu"
)

return LoRAKernelMeta(
token_lora_mapping=token_lora_mapping,
token_indices_sorted_by_lora_ids=token_indices_sorted_by_lora_ids,
active_lora_ids=active_lora_ids,
num_tokens_per_lora=num_tokens_per_lora,
lora_token_start_loc=lora_token_start_loc,
no_lora_flag_cpu=no_lora_flag_cpu,
num_active_loras_cpu=num_active_loras_cpu,
default_num_active_loras_cpu=default_num_active_loras_cpu,
captured_lora_counts=sorted(captured_lora_counts)
if captured_lora_counts
else [],
Expand All @@ -90,8 +104,7 @@ def _reset(self):
self.num_tokens_per_lora.fill_(0)
self.lora_token_start_loc.fill_(0)
self.no_lora_flag_cpu.fill_(False)
self.num_active_loras = 0
self.captured_lora_counts = []
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is self.captured_lora_counts removed by mistake ?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes @RunkaiTao don't we need this?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

self.captured_lora_counts is a configuration value (set once at init time to define the cudagraph capture keys). The _reset function, however, is called at the start of every prepare_tensors() call, so self.captured_lora_counts should not be reset there — it is not a per-batch tensor like active_lora_ids, num_tokens_per_lora, etc.

self.num_active_loras_cpu.fill_(0)

def prepare_tensors(self, token_lora_mapping: torch.Tensor) -> None:
"""
Expand Down Expand Up @@ -137,14 +150,16 @@ def prepare_tensors(self, token_lora_mapping: torch.Tensor) -> None:
num_tokens_per_lora, non_blocking=True
)

self.num_active_loras = lora_ids.size(0)
num_active_loras = lora_ids.size(0)

# Round up num_active_loras to match cudagraph capture keys.
# This ensures the kernel grid dimension matches the captured graph.
if self.captured_lora_counts and self.num_active_loras > 0:
idx = bisect.bisect_left(self.captured_lora_counts, self.num_active_loras)
if self.captured_lora_counts and num_active_loras > 0:
idx = bisect.bisect_left(self.captured_lora_counts, num_active_loras)
if idx < len(self.captured_lora_counts):
self.num_active_loras = self.captured_lora_counts[idx]
num_active_loras = self.captured_lora_counts[idx]

self.num_active_loras_cpu[0] = num_active_loras

# lora_token_start_loc
lora_token_start_loc = torch.cumsum(num_tokens_per_lora, dim=0)
Expand All @@ -163,7 +178,7 @@ def meta_args(
torch.Tensor,
torch.Tensor,
torch.Tensor,
int,
torch.Tensor,
]:
"""
This function returns the kernel metadata required for the current
Expand All @@ -175,13 +190,16 @@ def meta_args(
token_nums (int): Number of input tokens in the current forward
pass of the kernel.
"""
max_loras = self.active_lora_ids.size(0) - 1
if specialize_active_lora:
num_active_loras = self.num_active_loras_cpu
else:
num_active_loras = self.default_num_active_loras_cpu
return (
self.token_lora_mapping[:token_nums],
self.token_indices_sorted_by_lora_ids[:token_nums],
self.num_tokens_per_lora,
self.lora_token_start_loc,
self.active_lora_ids,
self.no_lora_flag_cpu,
self.num_active_loras if specialize_active_lora else max_loras + 1,
num_active_loras,
)
9 changes: 6 additions & 3 deletions vllm/lora/ops/triton_ops/lora_shrink_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def _lora_shrink(
lora_token_start_loc: torch.Tensor, # shape [max-loras + 2]
lora_ids: torch.Tensor, # shape [max-loras + 1]
no_lora_flag_cpu: torch.Tensor, # shape [1]
num_active_loras: int, # number of active LoRAs (unused here, for API compat)
num_active_loras: torch.Tensor, # CPU tensor [1], number of active LoRAs
scaling: float,
) -> None:
"""
Expand All @@ -157,6 +157,9 @@ def _lora_shrink(
lora_ids (torch.Tensor): LoRA ids to process.
no_lora_flag_cpu (torch.Tensor): A CPU tensor of size 1, that indicates
if there are any requests that require LoRA.
num_active_loras (torch.Tensor): A CPU tensor of size 1, containing the
number of active LoRAs. Stored as a tensor (not int) so
torch.compile treats it as dynamic rather than a constant.
scaling (float): Scaling factor.
"""

Expand Down Expand Up @@ -215,7 +218,7 @@ def _lora_shrink(
grid = (
SPLIT_K * triton.cdiv(M, BLOCK_M) * triton.cdiv(N, BLOCK_N),
NUM_SLICES,
num_active_loras,
num_active_loras.item(),
)
# We disable PDL temporarily because LoRA kernels are not launching back-to-back,
# making PDL invalid and affecting the kernel performance.
Expand Down Expand Up @@ -267,7 +270,7 @@ def _lora_shrink_fake(
lora_token_start_loc: torch.Tensor,
lora_ids: torch.Tensor,
no_lora_flag_cpu: torch.Tensor,
num_active_loras: int,
num_active_loras: torch.Tensor, # CPU tensor [1], number of active LoRAs
scaling: float,
) -> None:
return
Expand Down
1 change: 1 addition & 0 deletions vllm/v1/worker/gpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -5379,6 +5379,7 @@ def _capture_cudagraphs(
# if we want to warm up attention or not. This is
# different from the case where `FULL` implies capture
# attention while `PIECEWISE` implies no attention.

dummy_run(
num_tokens,
cudagraph_runtime_mode=CUDAGraphMode.NONE,
Expand Down