Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
ac26372
Init
jeejeelee Apr 26, 2026
b39f3a8
Merge remote-tracking branch 'origin/main' into fused-moe-triton-kernel
jeejeelee Apr 27, 2026
977f1a0
Move
jeejeelee Apr 27, 2026
778ea81
Merge remote-tracking branch 'origin/main' into fused-moe-triton-kernel
jeejeelee May 1, 2026
1f26ac7
Support dual streams
jeejeelee May 1, 2026
784a97b
Merge remote-tracking branch 'origin/main' into fused-moe-triton-kernel
jeejeelee May 7, 2026
ce0e0c8
Add glm config
jeejeelee May 8, 2026
c787eb3
Support samll batch fused kernel
jeejeelee May 11, 2026
e71fa73
Address conflict
jeejeelee May 11, 2026
9b1b132
Move
jeejeelee May 12, 2026
b1d60d9
Move
jeejeelee May 14, 2026
0a2fa4f
FIX
jeejeelee May 15, 2026
ece684a
Merge branch 'main' into fused-moe-triton-kernel
jeejeelee May 15, 2026
21460c6
Merge branch 'main' into fused-moe-triton-kernel
jeejeelee May 16, 2026
e4ec173
Shrink tests
jeejeelee May 16, 2026
2bf7071
Merge branch 'main' into fused-moe-triton-kernel
jeejeelee May 18, 2026
8dd901c
Merge branch 'main' into fused-moe-triton-kernel
jeejeelee May 20, 2026
5bd9ea9
FMT
jeejeelee May 20, 2026
135ef1f
Merge branch 'main' into fused-moe-triton-kernel
jeejeelee May 20, 2026
4d9174d
Merge branch 'main' into fused-moe-triton-kernel
jeejeelee May 20, 2026
d5754b0
Cleanup
jeejeelee May 20, 2026
7ed7946
Merge branch 'main' into fused-moe-triton-kernel
jeejeelee May 22, 2026
c9f0c89
Merge branch 'main' into fused-moe-triton-kernel
jeejeelee May 22, 2026
5a96756
Merge branch 'main' into fused-moe-triton-kernel
jeejeelee May 25, 2026
f1b6011
Merge branch 'main' into fused-moe-triton-kernel
ywang96 May 25, 2026
eab84fd
Merge branch 'main' into fused-moe-triton-kernel
jeejeelee May 25, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
465 changes: 465 additions & 0 deletions benchmarks/kernels/benchmark_fused_moe_lora_one_shot.py

Large diffs are not rendered by default.

495 changes: 495 additions & 0 deletions tests/lora/test_fused_moe_lora_kernel.py

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions vllm/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1971,6 +1971,8 @@ def _resolve_rust_frontend_path() -> str | None:
int(os.getenv("VLLM_USE_SIMPLE_KV_OFFLOAD", "0"))
),
# Whether to enable dual cuda streams for LoRA computation
# (used by both BaseLinearLayerWithLoRA and FusedMoEWithLoRA to
# overlap the base layer compute with the LoRA fast path).
"VLLM_LORA_ENABLE_DUAL_STREAM": lambda: bool(
int(os.getenv("VLLM_LORA_ENABLE_DUAL_STREAM", "0"))
),
Expand Down
9 changes: 1 addition & 8 deletions vllm/lora/layers/base_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,9 @@
from vllm.utils.torch_utils import direct_register_custom_op

from .base import BaseLayerWithLoRA
from .utils import _get_lora_device
from .utils import _get_lora_aux_cuda_stream, _get_lora_device

if envs.VLLM_LORA_ENABLE_DUAL_STREAM:
_lora_aux_cuda_stream: torch.cuda.Stream | None = None

def _get_lora_aux_cuda_stream() -> torch.cuda.Stream | None:
global _lora_aux_cuda_stream
if _lora_aux_cuda_stream is None and current_platform.is_cuda_alike():
_lora_aux_cuda_stream = torch.cuda.Stream()
return _lora_aux_cuda_stream

def lora_linear_async(
layer_name: str,
Expand Down
26 changes: 25 additions & 1 deletion vllm/lora/layers/fused_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
MoEPrepareAndFinalizeNoDPEPModular,
)
from vllm.platforms import current_platform

from .utils import _get_lora_device
from .utils import _get_lora_aux_cuda_stream, _get_lora_device


class FusedMoEWithLoRA(BaseLayerWithLoRA):
Expand All @@ -34,6 +35,9 @@ def __init__(self, base_layer: FusedMoE) -> None:
self.tp_size = self.base_layer.tp_size
self.tp_rank = self.base_layer.tp_rank
self.device = _get_lora_device(base_layer)

self._enable_aux_cuda_stream = envs.VLLM_LORA_ENABLE_DUAL_STREAM
self._init_lora_stream_context()
# For non-gated MoE (is_act_and_mul=False), only 1 slice is needed
# since there's only up_proj (w1), not gate_proj + up_proj (w1 + w3)
self._w13_slices = 2 if base_layer.moe_config.is_act_and_mul else 1
Expand Down Expand Up @@ -65,7 +69,25 @@ def __init__(self, base_layer: FusedMoE) -> None:
FusedMoEModularMethod(self.base_layer.quant_method, moe_kernel)
)

def _init_lora_stream_context(self) -> None:
self._lora_stream: torch.cuda.Stream | None = None
self._events: tuple[torch.cuda.Event, ...] | None = None
if not self._enable_aux_cuda_stream:
return
if not current_platform.is_cuda_alike():
return
self._lora_stream = _get_lora_aux_cuda_stream()
# 4 events: 2 per (base GEMM, LoRA) pair so w13 and w2 don't reuse
# the same event objects; reuse-within-a-pair is fine because the
# second pair starts only after intermediate_cache1.add_() has joined.
self._events = tuple(torch.cuda.Event() for _ in range(4))

def _build_lora_context(self):
use_dual_stream = (
self._enable_aux_cuda_stream
and not self.fully_sharded
and self._lora_stream is not None
)
Comment on lines +86 to +90
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The one-shot Triton kernel has a hard limit of max_lora_rank <= 128. For ranks exceeding this, the dual-stream path (which relies on the one-shot kernel's add_inputs=False contract) should be disabled to avoid a crash in the forward pass. This ensures that high-rank LoRAs fall back to the sequential legacy path which supports arbitrary ranks.

        use_dual_stream = (
            self._enable_aux_cuda_stream
            and not self.fully_sharded
            and self._lora_stream is not None
            and self.max_lora_rank <= 128
        )

return MoELoRAContext(
w13_lora_a_stacked=self.w13_lora_a_stacked,
w13_lora_b_stacked=self.w13_lora_b_stacked,
Expand All @@ -81,6 +103,8 @@ def _build_lora_context(self):
local_num_experts=self.base_layer.local_num_experts,
punica_wrapper=self.punica_wrapper,
use_tuned_config=bool(envs.VLLM_TUNED_CONFIG_FOLDER),
aux_stream=self._lora_stream if use_dual_stream else None,
events=self._events if use_dual_stream else None,
)

def _create_lora_a_weights(
Expand Down
13 changes: 13 additions & 0 deletions vllm/lora/layers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,22 @@
import torch
import torch.nn as nn

from vllm import envs
from vllm.model_executor.layers.fused_moe.fused_moe import try_get_optimal_moe_config
from vllm.platforms import current_platform
from vllm.utils.math_utils import next_power_of_2

_lora_aux_cuda_stream: torch.cuda.Stream | None = None


def _get_lora_aux_cuda_stream() -> torch.cuda.Stream | None:
if not envs.VLLM_LORA_ENABLE_DUAL_STREAM:
return None
global _lora_aux_cuda_stream
if _lora_aux_cuda_stream is None and current_platform.is_cuda_alike():
_lora_aux_cuda_stream = torch.cuda.Stream()
return _lora_aux_cuda_stream


class LoRAMappingType(Enum):
LANGUAGE = 1
Expand Down
Loading
Loading