From ac2637240f30792a7e595bc0e94378e88fb2aaa0 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Sun, 26 Apr 2026 04:43:03 +0000 Subject: [PATCH 01/10] Init Signed-off-by: Jee Jee Li --- .../benchmark_fused_moe_lora_one_shot.py | 436 +++++++++++++ tests/lora/test_fused_moe_lora_kernel.py | 574 ++++++++++++++++++ vllm/lora/ops/triton_ops/fused_moe_lora_op.py | 390 ++++++++++++ 3 files changed, 1400 insertions(+) create mode 100644 benchmarks/kernels/benchmark_fused_moe_lora_one_shot.py diff --git a/benchmarks/kernels/benchmark_fused_moe_lora_one_shot.py b/benchmarks/kernels/benchmark_fused_moe_lora_one_shot.py new file mode 100644 index 000000000000..6193bb9fa029 --- /dev/null +++ b/benchmarks/kernels/benchmark_fused_moe_lora_one_shot.py @@ -0,0 +1,436 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Benchmark the fused MoE-LoRA fast path (one-shot) vs two-kernel baseline. + +The "one_shot" provider goes through `vllm.lora.ops.triton_ops.fused_moe_lora` +which dispatches to the single-kernel one-shot implementation when +fully_sharded=False (the prefill default). + +The "two_kernel" provider drives `fused_moe_lora_shrink` + `fused_moe_lora_expand` +directly, bypassing the dispatch and matching the legacy two-kernel path's +work distribution. This isolates the win from kernel fusion. + +Run: + .venv/bin/python -m benchmarks.kernels.benchmark_fused_moe_lora_one_shot + .venv/bin/python -m benchmarks.kernels.benchmark_fused_moe_lora_one_shot \\ + --model qwen3moe +""" + +from __future__ import annotations + +import argparse +import random + +import torch + +from vllm import _custom_ops as ops +from vllm.lora.ops.triton_ops import ( + fused_moe_lora, + fused_moe_lora_expand, + fused_moe_lora_shrink, +) +from vllm.triton_utils import triton + +DTYPE = torch.bfloat16 +DEVICE = "cuda" + + +# ----- input fabrication ----------------------------------------------------- + + +def _round_up(x: int, base: int) -> int: + return ((x + base - 1) // base) * base + + +def _ceildiv(x: int, y: int) -> int: + return (x + y - 1) // y + + +def _assign_loras(num_tokens: int, num_sequences: int, max_loras: int) -> torch.Tensor: + tokens_per_seq = num_tokens // num_sequences + rem = num_tokens % num_sequences + out = torch.empty(num_tokens, dtype=torch.int32) + start = 0 + for i in range(num_sequences): + end = start + tokens_per_seq + (1 if i < rem else 0) + out[start:end] = random.randint(0, max_loras - 1) + start = end + return out + + +def _assign_experts(num_tokens: int, num_experts: int, top_k: int): + expert_indices = torch.empty((num_tokens, top_k), dtype=torch.int32) + for i in range(num_tokens): + expert_indices[i] = torch.randperm(num_experts)[:top_k] + weights = torch.rand((num_tokens, top_k), dtype=torch.float32) + weights = weights / weights.sum(dim=1, keepdim=True) + return expert_indices, weights + + +def _make_inputs( + M: int, + K: int, + N_per_slice: int, + rank: int, + num_experts: int, + top_k: int, + max_loras: int, + num_slices: int, + block_size_m: int, +): + """Mirrors the production caller's tensor layout.""" + torch.manual_seed(0) + random.seed(0) + + num_sequences = max(1, min(M, 8)) + topk_ids_cpu, topk_weights_cpu = _assign_experts(M, num_experts, top_k) + token_lora_cpu = _assign_loras(M, num_sequences, max_loras) + lora_ids_cpu = torch.full((max_loras + 1,), -1, dtype=torch.int32) + uniq = torch.unique(token_lora_cpu, sorted=True) + lora_ids_cpu[: uniq.size(0)].copy_(uniq) + + topk_ids = topk_ids_cpu.to(DEVICE) + topk_weights = topk_weights_cpu.to(device=DEVICE, dtype=DTYPE) + token_lora_mapping = token_lora_cpu.to(DEVICE) + lora_ids = lora_ids_cpu.to(DEVICE) + adapter_enabled = torch.ones(max_loras + 1, dtype=torch.int32, device=DEVICE) + + lora_a = [ + torch.randn((max_loras, num_experts, rank, K), dtype=DTYPE, device=DEVICE) + / max(K, 1) ** 0.5 + for _ in range(num_slices) + ] + lora_b = [ + torch.randn( + (max_loras, num_experts, N_per_slice, rank), + dtype=DTYPE, + device=DEVICE, + ) + / max(rank, 1) ** 0.5 + for _ in range(num_slices) + ] + hidden = torch.randn((M, K), dtype=DTYPE, device=DEVICE) + out_template = torch.zeros( + (M, top_k, num_slices * N_per_slice), dtype=DTYPE, device=DEVICE + ) + + # Sorted-path metadata (the prefill default). + max_pad = topk_ids.numel() + num_experts * (block_size_m - 1) + max_pad = _round_up(max_pad, block_size_m) + max_blocks = _ceildiv(max_pad, block_size_m) + sorted_token_ids = torch.empty( + (max_loras * max_pad,), dtype=torch.int32, device=DEVICE + ) + expert_ids = torch.empty( + (max_loras * max_blocks,), dtype=torch.int32, device=DEVICE + ) + num_post = torch.empty((max_loras,), dtype=torch.int32, device=DEVICE) + ops.moe_lora_align_block_size( + topk_ids, + token_lora_mapping, + num_experts, + block_size_m, + max_loras, + max_pad, + max_blocks, + sorted_token_ids, + expert_ids, + num_post, + adapter_enabled, + lora_ids, + ) + expert_ids = expert_ids.view(max_loras, -1).contiguous() + sorted_token_ids = sorted_token_ids.view(max_loras, -1).contiguous() + num_active = torch.tensor([max_loras + 1], dtype=torch.int32, device="cpu") + + return dict( + hidden=hidden, + lora_a=lora_a, + lora_b=lora_b, + topk_weights=topk_weights, + sorted_token_ids=sorted_token_ids, + expert_ids=expert_ids, + num_post=num_post, + token_lora_mapping=token_lora_mapping, + lora_ids=lora_ids, + num_active=num_active, + adapter_enabled=adapter_enabled, + out_template=out_template, + # bookkeeping + M=M, + K=K, + N_per_slice=N_per_slice, + rank=rank, + num_experts=num_experts, + top_k=top_k, + max_loras=max_loras, + num_slices=num_slices, + block_size_m=block_size_m, + ) + + +# ----- providers ------------------------------------------------------------- + + +def _run_one_shot(inp: dict): + """Drive `fused_moe_lora` with fully_sharded=False -> one-shot fast path.""" + out = inp["out_template"].clone() + fused_moe_lora( + out, + inp["hidden"], + inp["lora_a"], + inp["lora_b"], + inp["topk_weights"], + inp["sorted_token_ids"], + inp["expert_ids"], + inp["num_post"], + inp["token_lora_mapping"], + inp["rank"], + inp["top_k"], + inp["lora_ids"], + inp["num_active"], + inp["adapter_enabled"], + inp["block_size_m"], + 64, + 32, + 8, + 4, + 3, + 1, + inp["block_size_m"], + 64, + 32, + 8, + 4, + 3, + 1, + False, + False, + 0, + ) + return out + + +def _run_two_kernel(inp: dict): + """Drive `fused_moe_lora_shrink` + `fused_moe_lora_expand` directly, + bypassing the dispatch. Matches the legacy two-kernel work distribution. + """ + M = inp["M"] + top_k = inp["top_k"] + rank = inp["rank"] + num_slices = inp["num_slices"] + N_per_slice = inp["N_per_slice"] + K = inp["K"] + num_experts = inp["num_experts"] + block_m = inp["block_size_m"] + + intermediate = torch.zeros((num_slices, M, top_k, rank), dtype=DTYPE, device=DEVICE) + out = inp["out_template"].clone() + EM = inp["sorted_token_ids"].shape[1] + num_tokens = M * top_k + + fused_moe_lora_shrink( + intermediate, + inp["hidden"], + inp["lora_a"], + inp["topk_weights"], + inp["sorted_token_ids"], + inp["expert_ids"], + inp["num_post"], + inp["token_lora_mapping"], + top_k, + inp["lora_ids"], + inp["adapter_enabled"], + torch.device(DEVICE), + rank, + M, + EM, + K, + num_tokens, + num_experts, + num_slices, + block_m, + 64, + 32, + 8, + 4, + 3, + 1, + inp["num_active"], + False, + ) + fused_moe_lora_expand( + out, + intermediate, + inp["lora_b"], + inp["topk_weights"], + inp["sorted_token_ids"], + inp["expert_ids"], + inp["num_post"], + inp["token_lora_mapping"], + top_k, + inp["lora_ids"], + inp["adapter_enabled"], + torch.device(DEVICE), + rank, + M, + EM, + K, + num_tokens, + num_experts, + num_slices, + rank, + N_per_slice, + block_m, + 64, + 32, + 8, + 4, + 3, + 1, + inp["num_active"], + False, + 0, + ) + return out + + +PROVIDER_FNS = { + "one_shot": _run_one_shot, + "two_kernel": _run_two_kernel, +} + + +# ----- model presets --------------------------------------------------------- + + +MODEL_PRESETS: dict[str, dict] = { + # Mixtral-8x7B style: E=8, top_k=2, hidden=4096, intermediate=14336 + "mixtral": dict( + K=4096, + N_per_slice=7168, + num_experts=8, + top_k=2, + max_loras=4, + num_slices=2, + block_size_m=64, + ), + # Qwen3-MoE / DeepSeek-V2 style: E=64, top_k=8, hidden=2048, inter=1408 + "qwen3moe": dict( + K=2048, + N_per_slice=1408, + num_experts=64, + top_k=8, + max_loras=4, + num_slices=2, + block_size_m=64, + ), +} + + +M_RANGE = [16, 64, 256, 1024, 4096, 16384] +RANK_RANGE = [8, 16, 32, 64] + + +def get_benchmark(model: str): + preset = MODEL_PRESETS[model] + + @triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["M", "rank"], + x_vals=[(M, R) for M in M_RANGE for R in RANK_RANGE], + line_arg="provider", + line_vals=list(PROVIDER_FNS.keys()), + line_names=["one_shot (fused)", "two_kernel (legacy)"], + styles=[("red", "-"), ("blue", "-")], + ylabel="ms", + plot_name=f"fused_moe_lora-{model}", + args={"preset": preset}, + ) + ) + def benchmark(M, rank, provider, preset): + inp = _make_inputs( + M=M, + K=preset["K"], + N_per_slice=preset["N_per_slice"], + rank=rank, + num_experts=preset["num_experts"], + top_k=preset["top_k"], + max_loras=preset["max_loras"], + num_slices=preset["num_slices"], + block_size_m=preset["block_size_m"], + ) + fn = PROVIDER_FNS[provider] + quantiles = [0.5, 0.2, 0.8] + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: fn(inp), quantiles=quantiles + ) + return ms, max_ms, min_ms + + return benchmark + + +# ----- correctness sanity --------------------------------------------------- + + +def calculate_diff(model: str, M: int, rank: int): + preset = MODEL_PRESETS[model] + inp = _make_inputs( + M=M, + K=preset["K"], + N_per_slice=preset["N_per_slice"], + rank=rank, + num_experts=preset["num_experts"], + top_k=preset["top_k"], + max_loras=preset["max_loras"], + num_slices=preset["num_slices"], + block_size_m=preset["block_size_m"], + ) + out_one = _run_one_shot(inp) + out_two = _run_two_kernel(inp) + max_abs = (out_one.float() - out_two.float()).abs().max().item() + print( + f" model={model:<9} M={M:<6} rank={rank:<3} " + f"max|one_shot - two_kernel|={max_abs:.4g} " + f"ref|max|={out_two.float().abs().max().item():.3g}" + ) + if max_abs <= 5e-2: + print(" ✅ outputs match within bf16 tolerance") + else: + print(" ❌ outputs differ beyond expected bf16 noise") + + +# ----- main ------------------------------------------------------------------ + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--model", + type=str, + default="mixtral", + choices=list(MODEL_PRESETS.keys()), + help="Model preset to sweep", + ) + parser.add_argument( + "--save-path", + type=str, + default="./configs/fused_moe_lora_one_shot/", + help="Directory to save benchmark results", + ) + parser.add_argument( + "--check-only", + action="store_true", + help="Run correctness sanity check only, no perf sweep", + ) + args = parser.parse_args() + + print(f"Correctness check ({args.model}):") + calculate_diff(args.model, M=256, rank=32) + if args.check_only: + raise SystemExit(0) + + print(f"\nGPU: {torch.cuda.get_device_name()}") + print(f"Model preset: {args.model}\n") + benchmark = get_benchmark(args.model) + benchmark.run(print_data=True, save_path=args.save_path) diff --git a/tests/lora/test_fused_moe_lora_kernel.py b/tests/lora/test_fused_moe_lora_kernel.py index 8adc20865755..720db2f788fa 100644 --- a/tests/lora/test_fused_moe_lora_kernel.py +++ b/tests/lora/test_fused_moe_lora_kernel.py @@ -744,3 +744,577 @@ def _get_shard_slice(shard_size): output = tensor_model_parallel_all_reduce(output) torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2) + + +# -- one-shot fast-path coverage -------------------------------------------- +# The fused shrink+expand one-shot kernel pads `BLOCK_R` to next_pow2(rank), +# with a floor of 16 (tensor-core minimum). Small ranks (4, 8) exercise the +# rank-dim masking and are not covered by the original tests, which start at +# rank=16. The legacy two-kernel path additionally fails on rank=4 in TMA +# mode because the rank-dim stride (rank * elem_size) is not 16-byte +# aligned; the one-shot fast path takes precedence whenever fully_sharded +# is False so this regression is hidden in normal use, but the test still +# ensures the one-shot logic is correct against the pytorch reference. + + +@pytest.mark.parametrize("num_tokens", [16, 100]) +@pytest.mark.parametrize("top_k_num", [2]) +@pytest.mark.parametrize("num_experts", [8, 64]) +@pytest.mark.parametrize("max_loras", [4]) +@pytest.mark.parametrize("N", [1408]) +@pytest.mark.parametrize("K", [2048]) +@pytest.mark.parametrize("max_lora_rank", [4, 8]) +@pytest.mark.parametrize("block_size", [16, 64]) +@pytest.mark.parametrize("num_slices", [1, 2]) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("device", DEVICES) +@pytest.mark.parametrize("seed", SEED) +def test_fused_moe_lora_kernel_small_rank( + num_tokens, + top_k_num, + num_experts, + max_loras, + N, + K, + max_lora_rank, + block_size, + num_slices, + dtype, + device, + seed, +): + """One-shot fast path covering rank<16 (padded to BLOCK_R=16 inside kernel).""" + torch.set_default_device(device) + set_random_seed(seed) + num_sequences = max(1, min(num_tokens, 8)) + topk_ids, topk_weights, token_lora_mapping, lora_ids = sample_data( + num_tokens, num_sequences, max_loras, num_experts, top_k_num + ) + + lora_a_stacked = [ + torch.rand( + (max_loras, num_experts, max_lora_rank, K), + dtype=dtype, + ) + for _ in range(num_slices) + ] + lora_b_stacked = [ + torch.rand( + (max_loras, num_experts, N // num_slices, max_lora_rank), + dtype=dtype, + ) + for _ in range(num_slices) + ] + hidden_states = torch.rand((num_tokens, K), dtype=dtype) + + output = torch.zeros((num_tokens, top_k_num, N), dtype=dtype) + use_fused_moe_lora_kernel( + topk_ids, + topk_weights, + token_lora_mapping, + max_lora_rank, + top_k_num, + lora_ids, + lora_a_stacked, + lora_b_stacked, + hidden_states, + output, + max_loras, + num_experts, + block_size, + ) + output_ref = use_torch( + hidden_states, + token_lora_mapping, + topk_ids, + lora_a_stacked, + lora_b_stacked, + top_k_num, + num_slices, + ) + torch.testing.assert_close(output, output_ref, atol=1e-2, rtol=1e-2) + + +@pytest.mark.parametrize("num_tokens", [16, 64]) +@pytest.mark.parametrize("top_k_num", [2]) +@pytest.mark.parametrize("num_experts", [8]) +@pytest.mark.parametrize("max_loras", [4]) +@pytest.mark.parametrize("N", [2048]) +@pytest.mark.parametrize("K", [4096]) +@pytest.mark.parametrize("max_lora_rank", [8, 16, 32, 64]) +@pytest.mark.parametrize("block_size", [64]) +@pytest.mark.parametrize("num_slices", [2]) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("device", DEVICES) +@pytest.mark.parametrize("seed", SEED) +def test_fused_moe_lora_kernel_npid_path( + num_tokens, + top_k_num, + num_experts, + max_loras, + N, + K, + max_lora_rank, + block_size, + num_slices, + dtype, + device, + seed, +): + """Exercise the small-batch / NPID > 1 branch of the one-shot fast path. + + With these sizes the one-shot wrapper computes NPID_FACTOR > 1 (base CTA count + < SM count), so each program covers only an outer chunk of N. The + cross-outer-block write mask is the correctness-critical bit. + """ + torch.set_default_device(device) + set_random_seed(seed) + num_sequences = max(1, min(num_tokens, 4)) + topk_ids, topk_weights, token_lora_mapping, lora_ids = sample_data( + num_tokens, num_sequences, max_loras, num_experts, top_k_num + ) + + lora_a_stacked = [ + torch.rand( + (max_loras, num_experts, max_lora_rank, K), + dtype=dtype, + ) + for _ in range(num_slices) + ] + lora_b_stacked = [ + torch.rand( + (max_loras, num_experts, N // num_slices, max_lora_rank), + dtype=dtype, + ) + for _ in range(num_slices) + ] + hidden_states = torch.rand((num_tokens, K), dtype=dtype) + + output = torch.zeros((num_tokens, top_k_num, N), dtype=dtype) + use_fused_moe_lora_kernel( + topk_ids, + topk_weights, + token_lora_mapping, + max_lora_rank, + top_k_num, + lora_ids, + lora_a_stacked, + lora_b_stacked, + hidden_states, + output, + max_loras, + num_experts, + block_size, + ) + output_ref = use_torch( + hidden_states, + token_lora_mapping, + topk_ids, + lora_a_stacked, + lora_b_stacked, + top_k_num, + num_slices, + ) + torch.testing.assert_close(output, output_ref, atol=2e-2, rtol=2e-2) + + +# -- one-shot corner-case coverage ------------------------------------------ +# Each of the following exercises a path where the kernel is launched but +# every program early-exits, leaving the output unchanged. The contract is +# additive (`output += contribution`), so an empty contribution must leave +# the input residual untouched. + + +def _build_one_shot_inputs( + num_tokens, + top_k_num, + num_experts, + max_loras, + max_lora_rank, + K, + N, + num_slices, + block_size, + dtype, +): + """Common scaffolding for the corner-case tests below.""" + num_sequences = max(1, min(num_tokens, 4)) if num_tokens > 0 else 1 + if num_tokens > 0: + topk_ids, topk_weights, token_lora_mapping, lora_ids = sample_data( + num_tokens, num_sequences, max_loras, num_experts, top_k_num + ) + else: + # M=0 path: caller may still hand us empty tensors with the right shape. + topk_ids = torch.empty((0, top_k_num), dtype=torch.int32) + topk_weights = torch.empty((0, top_k_num), dtype=torch.float32) + token_lora_mapping = torch.empty((0,), dtype=torch.int32) + lora_ids = torch.full((max_loras + 1,), -1, dtype=torch.int32) + + lora_a_stacked = [ + torch.rand((max_loras, num_experts, max_lora_rank, K), dtype=dtype) + for _ in range(num_slices) + ] + lora_b_stacked = [ + torch.rand( + (max_loras, num_experts, N // num_slices, max_lora_rank), dtype=dtype + ) + for _ in range(num_slices) + ] + hidden_states = torch.rand((max(num_tokens, 0), K), dtype=dtype) + return ( + topk_ids, + topk_weights.to(dtype), + token_lora_mapping, + lora_ids, + lora_a_stacked, + lora_b_stacked, + hidden_states, + ) + + +def _call_one_shot( + output, + hidden_states, + lora_a_stacked, + lora_b_stacked, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + token_lora_mapping, + max_lora_rank, + top_k_num, + lora_ids, + num_active_loras, + adapter_enabled, + block_size, +): + """Direct call into fused_moe_lora with one-shot-routed defaults.""" + from vllm.lora.ops.triton_ops import fused_moe_lora as _op + + _op( + output, + hidden_states, + lora_a_stacked, + lora_b_stacked, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + token_lora_mapping, + max_lora_rank, + top_k_num, + lora_ids, + num_active_loras, + adapter_enabled, + block_size, + 32, + 64, + 1, + 4, + 3, + 1, + block_size, + 32, + 64, + 1, + 4, + 3, + 1, + False, + False, + 0, + ) + + +@pytest.mark.parametrize("device", DEVICES) +def test_fused_moe_lora_kernel_no_active_loras(device): + """SORTED path: all entries in lora_ids are -1 -> every program must + early-exit at the lora_id<0 check and output must remain at its + pre-call residual value byte-for-byte.""" + torch.set_default_device(device) + set_random_seed(0) + num_tokens, top_k, E, max_loras, R, K, N = 32, 2, 8, 4, 16, 1024, 1024 + block_size, num_slices, dtype = 16, 2, torch.bfloat16 + + ( + topk_ids, + topk_weights, + token_lora_mapping, + _, + lora_a_stacked, + lora_b_stacked, + hidden_states, + ) = _build_one_shot_inputs( + num_tokens, + top_k, + E, + max_loras, + R, + K, + N, + num_slices, + block_size, + dtype, + ) + + lora_ids = torch.full((max_loras + 1,), -1, dtype=torch.int32) + adapter_enabled = torch.ones(max_loras + 1, dtype=torch.int32) + num_active_loras = torch.tensor([max_loras + 1], dtype=torch.int32, device="cpu") + residual = torch.randn((num_tokens, top_k, N), dtype=dtype) * 0.1 + output = residual.clone() + + # Sorted path with empty alignment: build minimal valid metadata. + max_pad = topk_ids.numel() + E * (block_size - 1) + max_pad = round_up(max_pad, block_size) + max_blocks = CEILDIV(max_pad, block_size) + sorted_token_ids = torch.zeros((max_loras, max_pad), dtype=torch.int32) + expert_ids = torch.full((max_loras, max_blocks), -1, dtype=torch.int32) + num_post = torch.zeros((max_loras,), dtype=torch.int32) + _call_one_shot( + output, + hidden_states, + lora_a_stacked, + lora_b_stacked, + topk_weights, + sorted_token_ids, + expert_ids, + num_post, + token_lora_mapping, + R, + top_k, + lora_ids, + num_active_loras, + adapter_enabled, + block_size, + ) + torch.testing.assert_close(output, residual, atol=0, rtol=0) + + +@pytest.mark.parametrize("device", DEVICES) +def test_fused_moe_lora_kernel_naive_no_lora_tokens(device): + """NAIVE path: token_lora_mapping is all -1 -> lora_id resolves to -1 + in the kernel and every program early-exits. Residual preserved.""" + torch.set_default_device(device) + set_random_seed(0) + num_tokens, top_k, E, max_loras, R, K, N = 8, 2, 64, 8, 16, 1024, 1024 + num_slices, dtype = 2, torch.bfloat16 + + (topk_ids, topk_weights, _, _, lora_a_stacked, lora_b_stacked, hidden_states) = ( + _build_one_shot_inputs( + num_tokens, + top_k, + E, + max_loras, + R, + K, + N, + num_slices, + 16, + dtype, + ) + ) + token_lora_mapping = torch.full((num_tokens,), -1, dtype=torch.int32) + lora_ids = torch.full((max_loras + 1,), -1, dtype=torch.int32) + adapter_enabled = torch.ones(max_loras + 1, dtype=torch.int32) + num_active_loras = torch.tensor([max_loras + 1], dtype=torch.int32, device="cpu") + residual = torch.randn((num_tokens, top_k, N), dtype=dtype) * 0.1 + output = residual.clone() + expert_ids = topk_ids.reshape(-1).contiguous() + _call_one_shot( + output, + hidden_states, + lora_a_stacked, + lora_b_stacked, + topk_weights, + None, + expert_ids, + None, + token_lora_mapping, + R, + top_k, + lora_ids, + num_active_loras, + adapter_enabled, + 16, + ) + torch.testing.assert_close(output, residual, atol=0, rtol=0) + + +@pytest.mark.parametrize("device", DEVICES) +def test_fused_moe_lora_kernel_all_disabled(device): + """adapter_enabled is all-zero: every program must early-exit at the + enabled check; residual preserved.""" + torch.set_default_device(device) + set_random_seed(0) + num_tokens, top_k, E, max_loras, R, K, N = 32, 2, 8, 4, 32, 1024, 1024 + block_size, num_slices, dtype = 16, 2, torch.bfloat16 + + ( + topk_ids, + topk_weights, + token_lora_mapping, + lora_ids, + lora_a_stacked, + lora_b_stacked, + hidden_states, + ) = _build_one_shot_inputs( + num_tokens, + top_k, + E, + max_loras, + R, + K, + N, + num_slices, + block_size, + dtype, + ) + adapter_enabled = torch.zeros(max_loras + 1, dtype=torch.int32) + num_active_loras = torch.tensor([max_loras + 1], dtype=torch.int32, device="cpu") + residual = torch.randn((num_tokens, top_k, N), dtype=dtype) * 0.1 + output = residual.clone() + + expert_ids = topk_ids.reshape(-1).contiguous() + _call_one_shot( + output, + hidden_states, + lora_a_stacked, + lora_b_stacked, + topk_weights, + None, + expert_ids, + None, + token_lora_mapping, + R, + top_k, + lora_ids, + num_active_loras, + adapter_enabled, + block_size, + ) + torch.testing.assert_close(output, residual, atol=0, rtol=0) + + +@pytest.mark.parametrize("device", DEVICES) +def test_fused_moe_lora_kernel_zero_grid_no_crash(device): + """num_active_loras=0 (or num_slices=0) would otherwise launch a grid + with a zero dimension. one-shot wrapper must short-circuit before launch.""" + torch.set_default_device(device) + set_random_seed(0) + num_tokens, top_k, E, max_loras, R, K, N = 8, 2, 8, 4, 16, 1024, 1024 + block_size, num_slices, dtype = 16, 2, torch.bfloat16 + + ( + topk_ids, + topk_weights, + token_lora_mapping, + lora_ids, + lora_a_stacked, + lora_b_stacked, + hidden_states, + ) = _build_one_shot_inputs( + num_tokens, + top_k, + E, + max_loras, + R, + K, + N, + num_slices, + block_size, + dtype, + ) + adapter_enabled = torch.ones(max_loras + 1, dtype=torch.int32) + num_active_loras = torch.tensor([0], dtype=torch.int32, device="cpu") + residual = torch.randn((num_tokens, top_k, N), dtype=dtype) * 0.1 + output = residual.clone() + + # sorted path is the one that uses num_active_loras for grid axis 2 + max_pad = topk_ids.numel() + E * (block_size - 1) + max_pad = round_up(max_pad, block_size) + max_blocks = CEILDIV(max_pad, block_size) + sorted_token_ids = torch.zeros((max_loras, max_pad), dtype=torch.int32) + expert_ids = torch.full((max_loras, max_blocks), -1, dtype=torch.int32) + num_post = torch.zeros((max_loras,), dtype=torch.int32) + _call_one_shot( + output, + hidden_states, + lora_a_stacked, + lora_b_stacked, + topk_weights, + sorted_token_ids, + expert_ids, + num_post, + token_lora_mapping, + R, + top_k, + lora_ids, + num_active_loras, + adapter_enabled, + block_size, + ) + torch.testing.assert_close(output, residual, atol=0, rtol=0) + + +@pytest.mark.parametrize("device", DEVICES) +def test_fused_moe_lora_kernel_rejects_bad_block_size_m(device): + """one-shot must surface a clear assertion when shrink_block_size_m is not + a power of 2 / less than 16, instead of the cryptic Triton compile + failure (`arange's range must be a power of 2`).""" + torch.set_default_device(device) + set_random_seed(0) + num_tokens, top_k, E, max_loras, R, K, N = 32, 2, 8, 4, 16, 1024, 1024 + num_slices, dtype = 2, torch.bfloat16 + block_size = 24 # NOT a power of 2 + + ( + topk_ids, + topk_weights, + token_lora_mapping, + lora_ids, + lora_a_stacked, + lora_b_stacked, + hidden_states, + ) = _build_one_shot_inputs( + num_tokens, + top_k, + E, + max_loras, + R, + K, + N, + num_slices, + 16, + dtype, + ) + # Build sorted-mode metadata at block_size=16 so shapes are sane, + # but pass block_size=24 to the op (the buggy combination). + max_pad = topk_ids.numel() + E * (16 - 1) + max_pad = round_up(max_pad, 16) + max_blocks = CEILDIV(max_pad, 16) + sorted_token_ids = torch.zeros((max_loras, max_pad), dtype=torch.int32) + expert_ids = torch.full((max_loras, max_blocks), -1, dtype=torch.int32) + num_post = torch.zeros((max_loras,), dtype=torch.int32) + adapter_enabled = torch.ones(max_loras + 1, dtype=torch.int32) + num_active_loras = torch.tensor([max_loras + 1], dtype=torch.int32, device="cpu") + output = torch.zeros((num_tokens, top_k, N), dtype=dtype) + + with pytest.raises(AssertionError, match="shrink_block_size_m"): + _call_one_shot( + output, + hidden_states, + lora_a_stacked, + lora_b_stacked, + topk_weights, + sorted_token_ids, + expert_ids, + num_post, + token_lora_mapping, + R, + top_k, + lora_ids, + num_active_loras, + adapter_enabled, + block_size, + ) diff --git a/vllm/lora/ops/triton_ops/fused_moe_lora_op.py b/vllm/lora/ops/triton_ops/fused_moe_lora_op.py index 42f53b200fa3..8793d6113b48 100644 --- a/vllm/lora/ops/triton_ops/fused_moe_lora_op.py +++ b/vllm/lora/ops/triton_ops/fused_moe_lora_op.py @@ -105,6 +105,365 @@ def _get_c_ptrs( _LORA_PTR_DICT: dict[tuple[int, ...], torch.tensor] = {} +# --------------------------------------------------------------------------- +# Fully-fused MoE-LoRA kernel (one-shot): shrink + expand combined into a single +# launch with the rank-dim intermediate kept in registers. Used by the fast +# path of `_fused_moe_lora` for `fully_sharded=False`. The legacy two-kernel +# path (`_fused_moe_lora_kernel` above) is retained for `fully_sharded=True` +# because that path needs to materialise the intermediate cache for an +# all_reduce / all_gather between shrink and expand. +# --------------------------------------------------------------------------- + + +@triton.heuristics({"EVEN_K": lambda args: args["K"] % args["BLOCK_K"] == 0}) +@triton.jit +def _fused_moe_lora_one_shot_kernel( + # ---- pointers ---- + x_ptr, + A_ptrs, + B_ptrs, + out_ptr, + topk_weights_ptr, + sorted_token_ids_ptr, + expert_ids_ptr, + num_tokens_post_padded_ptr, + token_lora_mapping_ptr, + lora_ids_ptr, + adapter_enabled_ptr, + # ---- dims ---- + N, + K, + num_valid_tokens, + top_k_num, + max_loras, + # ---- strides ---- + stride_xm, + stride_xk, + stride_A_lora, + stride_A_expert, + stride_A_r, + stride_A_k, + stride_B_lora, + stride_B_expert, + stride_B_n, + stride_B_r, + stride_om, + stride_on, + stride_tl_, + stride_el, + # ---- scalar ---- + slice_n_offset, + # ---- constexpr (set per call) ---- + token_mapping_factor: tl.constexpr, + naive_block_assignment: tl.constexpr, + MUL_ROUTED_WEIGHT: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_R: tl.constexpr, + actual_rank: tl.constexpr, + NPID_FACTOR: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + EVEN_K: tl.constexpr, +): + pid_full = tl.program_id(axis=0) + pid_m = pid_full // NPID_FACTOR + pid_n_outer = pid_full % NPID_FACTOR + slice_id = tl.program_id(axis=1) + lora_idx = tl.program_id(axis=2) + + # Resolve lora_id. + if naive_block_assignment: + token_idx_for_lora = pid_m // top_k_num + lora_id = tl.load(token_lora_mapping_ptr + token_idx_for_lora) + else: + lora_id = tl.load(lora_ids_ptr + lora_idx) + if lora_id < 0: + return + if lora_id >= max_loras: + return + enabled = tl.load(adapter_enabled_ptr + lora_id) + if enabled == 0: + return + + if not naive_block_assignment: + ntpp = tl.load(num_tokens_post_padded_ptr + lora_id) + if pid_m * BLOCK_M >= ntpp: + return + + # Resolve expert_id. + if naive_block_assignment: + expert_id = tl.load(expert_ids_ptr + pid_m) + else: + ind = lora_id * stride_el + pid_m + expert_id = tl.load( + expert_ids_ptr + ind, mask=ind < max_loras * stride_el, other=-1 + ) + if expert_id < 0: + return + + # Compute offs_token (flat token ids). + offs = tl.arange(0, BLOCK_M).to(tl.int64) + if naive_block_assignment: + offs_token = tl.where(offs == 0, pid_m, num_valid_tokens) + else: + offs_token_id = pid_m * BLOCK_M + offs + token_ind = stride_tl_ * lora_id + offs_token_id + offs_token = tl.load( + sorted_token_ids_ptr + token_ind, + mask=token_ind < max_loras * stride_tl_, + other=num_valid_tokens, + ) + token_mask = offs_token < num_valid_tokens + + # N range owned by this program. Splitting [0, N) into NPID_FACTOR + # contiguous outer blocks lets us scale parallelism for small batches. + n_per_outer = tl.cdiv(N, NPID_FACTOR) + n_lo = pid_n_outer * n_per_outer + n_hi = tl.minimum((pid_n_outer + 1) * n_per_outer, N) + if n_lo >= N: + return + + # Slice pointers. + cur_A_ptr = tl.load(A_ptrs + slice_id).to(tl.pointer_type(out_ptr.dtype.element_ty)) + cur_B_ptr = tl.load(B_ptrs + slice_id).to(tl.pointer_type(out_ptr.dtype.element_ty)) + + A_base = cur_A_ptr + lora_id * stride_A_lora + expert_id * stride_A_expert + B_base = cur_B_ptr + lora_id * stride_B_lora + expert_id * stride_B_expert + + # SHRINK: tmp[BLOCK_M, BLOCK_R] = x @ A^T, accumulated in fp32 registers. + offs_r = tl.arange(0, BLOCK_R) + rank_mask = offs_r < actual_rank + # Clamp rank offsets so OOB rows of A / B map to address 0; the mask + # zeros the loaded values. Required when BLOCK_R > actual_rank + # (e.g. rank=4 padded to 16) -- without clamping, tl.load would address + # the next expert's memory. + safe_offs_r = tl.where(rank_mask, offs_r, 0) + offs_k = tl.arange(0, BLOCK_K) + + offs_x_row = offs_token // token_mapping_factor + x_ptrs = x_ptr + offs_x_row[:, None] * stride_xm + offs_k[None, :] * stride_xk + a_ptrs = A_base + offs_k[:, None] * stride_A_k + safe_offs_r[None, :] * stride_A_r + + tmp = tl.zeros((BLOCK_M, BLOCK_R), dtype=tl.float32) + if EVEN_K: + for _ in range(0, K, BLOCK_K): + x = tl.load(x_ptrs, mask=token_mask[:, None], other=0.0) + a = tl.load(a_ptrs, mask=rank_mask[None, :], other=0.0) + tmp += tl.dot(x, a) + x_ptrs += BLOCK_K * stride_xk + a_ptrs += BLOCK_K * stride_A_k + else: + for kb in range(0, K, BLOCK_K): + k_remain = K - kb + k_mask = offs_k < k_remain + x = tl.load(x_ptrs, mask=token_mask[:, None] & k_mask[None, :], other=0.0) + a = tl.load(a_ptrs, mask=k_mask[:, None] & rank_mask[None, :], other=0.0) + tmp += tl.dot(x, a) + x_ptrs += BLOCK_K * stride_xk + a_ptrs += BLOCK_K * stride_A_k + + tmp_typed = tmp.to(out_ptr.dtype.element_ty) + + # EXPAND: out[tokens, n] += tmp @ B^T, looped over BLOCK_N tiles within + # this program's [n_lo, n_hi). The (offs_n < n_hi) mask is required + # whenever BLOCK_N > n_per_outer to keep adjacent outer blocks from + # writing into each other's columns. + if MUL_ROUTED_WEIGHT: + moe_w = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0.0).to( + tl.float32 + ) + + out_slice_base = out_ptr + slice_id * slice_n_offset + + for n_start in range(n_lo, n_hi, BLOCK_N): + offs_n = n_start + tl.arange(0, BLOCK_N) + n_mask = (offs_n < N) & (offs_n < n_hi) + + b_ptrs = ( + B_base + safe_offs_r[:, None] * stride_B_r + offs_n[None, :] * stride_B_n + ) + b = tl.load(b_ptrs, mask=rank_mask[:, None] & n_mask[None, :], other=0.0) + + acc = tl.dot(tmp_typed, b) # (BLOCK_M, BLOCK_N) fp32 + if MUL_ROUTED_WEIGHT: + acc = acc * moe_w[:, None] + + out_ptrs = ( + out_slice_base + + offs_token[:, None] * stride_om + + offs_n[None, :] * stride_on + ) + out_mask = token_mask[:, None] & n_mask[None, :] + prev = tl.load(out_ptrs, mask=out_mask, other=0.0) + tl.store(out_ptrs, prev + acc.to(out_ptr.dtype.element_ty), mask=out_mask) + + +def _run_fused_moe_lora_one_shot( + output: torch.Tensor, + qcurr_hidden_states: torch.Tensor, + lora_a_stacked: list[torch.Tensor], + lora_b_stacked: list[torch.Tensor], + topk_weights: torch.Tensor, + sorted_token_ids: torch.Tensor | None, + expert_ids: torch.Tensor, + num_tokens_post_padded: torch.Tensor | None, + token_lora_mapping: torch.Tensor, + max_lora_rank: int, + top_k_num: int, + lora_ids: torch.Tensor, + num_active_loras: torch.Tensor, + adapter_enabled: torch.Tensor, + mul_routed_weight: bool, + block_size_m: int, +) -> None: + """Fast-path wrapper: launches one fused shrink+expand kernel. + + The shape contract matches `_fused_moe_lora`. `output` is expected to + have shape `(num_tokens, top_k_num, num_slices * N_per_slice)` and is + accumulated in-place. Only used when `fully_sharded=False` and the + caller's `offset` is 0 (which it always is in that case -- see + vllm/lora/layers/fused_moe.py). + """ + num_slices = len(lora_a_stacked) + device = qcurr_hidden_states.device + + A0 = lora_a_stacked[0] + B0 = lora_b_stacked[0] + max_loras_w = A0.shape[0] + rank = A0.shape[2] + K = A0.shape[3] + N_per_slice = B0.shape[2] + + # rank padding is to next pow2 with a floor of 16 (tensor-core minimum + # K-dim). Beyond 128 the (BLOCK_M, BLOCK_R) accumulator outgrows the + # register file; rank tiling would be needed but is out of scope for + # this kernel. + assert rank <= 128, ( + f"fused_moe_lora_one_shot supports max_lora_rank<=128; got rank={rank}" + ) + BLOCK_R = max(triton.next_power_of_2(rank), 16) + + naive = sorted_token_ids is None + if naive: + EM_grid = topk_weights.numel() + BLOCK_M = 16 + stride_tl_ = 0 + stride_el = 0 + grid_lora_dim = 1 + else: + EM_grid = sorted_token_ids.shape[1] + # BLOCK_M must equal moe_lora_align_block_size's block_size. The + # caller passes that explicitly; deriving it from tensor shapes is + # unsafe because sorted_token_ids.shape[1] is the raw padded length + # (not necessarily a multiple of block_size — e.g. OLMoE prefill + # produces sorted=139200 with expert_ids=1088 and block_size=128). + # tl.arange and tl.dot need block_size_m to be a power of 2 and at + # least 16. The Python-side assertion gives a clearer error than + # the cryptic Triton compile failure. + assert block_size_m >= 16 and (block_size_m & (block_size_m - 1)) == 0, ( + f"shrink_block_size_m must be a power of 2 and >=16; got {block_size_m}" + ) + BLOCK_M = block_size_m + stride_tl_ = sorted_token_ids.stride(0) + stride_el = expert_ids.stride(0) + grid_lora_dim = int(num_active_loras.item()) + + # Empty-work guards: the grid would otherwise have a zero dimension, + # which Triton rejects. None of these is a hot path in production -- a + # batch with zero tokens, an EM_grid of zero, or zero active LoRAs all + # mean there's nothing to add to `output`. + if EM_grid == 0 or grid_lora_dim == 0 or num_slices == 0: + return + + token_mapping_factor = 1 if mul_routed_weight else top_k_num + + A_ptrs = _get_ptr(lora_a_stacked, device) + B_ptrs = _get_ptr(lora_b_stacked, device) + + # Flatten (num_tokens, top_k) → flat_token axis. The kernel addresses + # output via offs_token * stride_om, which is correct iff the dim-0 / + # dim-1 strides collapse cleanly: stride(0) == top_k * stride(1). All + # production callers pass contiguous output, so this always holds; the + # explicit check guards against future regressions where a non-trivial + # view (e.g. permute) would silently break in-place accumulation. + assert output.dim() == 3, f"output must be 3-D, got {output.shape}" + assert output.stride(0) == output.shape[1] * output.stride(1), ( + "fused_moe_lora_one_shot requires output.stride(0) == top_k*stride(1); " + f"got shape={output.shape} strides={output.stride()}" + ) + out_view = output.view(-1, output.shape[-1]) + M_blocks = triton.cdiv(EM_grid, BLOCK_M) if not naive else EM_grid + + # NPID_FACTOR heuristic: scale N-axis parallelism when base CTA count is + # short of saturating the SM array. Cap by the cost of redundant shrink. + sm_count = torch.cuda.get_device_properties(device).multi_processor_count + base_programs = max(M_blocks * num_slices * grid_lora_dim, 1) + shrink_ratio = K / max(K + N_per_slice, 1) + max_npid_by_budget = max(1, int(1.5 / max(shrink_ratio, 1e-3)) + 1) + target = 2 * sm_count + if base_programs >= int(1.5 * sm_count): + npid = 1 + else: + npid_occ = max(1, min(16, (target + base_programs - 1) // base_programs)) + npid = min(npid_occ, max_npid_by_budget) + npid = max(1, min(npid, max(1, N_per_slice // 128))) + + # Robust defaults across the prefill regime (H100/H200, bf16/fp16). + # NPID > 1 is the small-M / under-saturated path -- more warps + a + # deeper pipeline help amortise the inner-N expand loop. + if npid > 1: + block_n, block_k, nw, ns = 128, 64, 8, 4 + else: + block_n, block_k, nw, ns = 128, 64, 4, 3 + + grid = (M_blocks * npid, num_slices, grid_lora_dim) + + _fused_moe_lora_one_shot_kernel[grid]( + qcurr_hidden_states, + A_ptrs, + B_ptrs, + out_view, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + token_lora_mapping, + lora_ids, + adapter_enabled, + N_per_slice, + K, + topk_weights.numel(), + top_k_num, + max_loras_w, + qcurr_hidden_states.stride(0), + qcurr_hidden_states.stride(1), + A0.stride(0), + A0.stride(1), + A0.stride(2), + A0.stride(3), + B0.stride(0), + B0.stride(1), + B0.stride(2), + B0.stride(3), + out_view.stride(0), + out_view.stride(1), + stride_tl_, + stride_el, + N_per_slice, + token_mapping_factor=token_mapping_factor, + naive_block_assignment=naive, + MUL_ROUTED_WEIGHT=mul_routed_weight, + BLOCK_M=BLOCK_M, + BLOCK_R=BLOCK_R, + actual_rank=rank, + NPID_FACTOR=npid, + BLOCK_N=block_n, + BLOCK_K=block_k, + num_warps=nw, + num_stages=ns, + ) + + def _get_ptr(lora_weights: list[torch.Tensor], device: torch.device): """ `_LORA_PTR_DICT` collects the required information during `profile_run`, @@ -728,6 +1087,37 @@ def _fused_moe_lora( ) assert output.shape[0] == topk_weights.shape[0] assert top_k_num == topk_weights.shape[1] + + # Fast path: single fused kernel keeps the rank-dim intermediate in + # registers and avoids the HBM round-trip of the legacy two-kernel + # implementation. fully_sharded=True still needs the materialised + # intermediate cache so that an all_reduce / all_gather can flow + # between shrink and expand, so it falls through to the legacy path. + if not fully_sharded: + # shrink/expand BLOCK_SIZE_M must match the block_size that + # moe_lora_align_block_size used; both shrink and expand pass the + # same value (asserted by `shrink_block_size_m == expand_block_size_m` + # below). + _run_fused_moe_lora_one_shot( + output, + qcurr_hidden_states, + lora_a_stacked, + lora_b_stacked, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + token_lora_mapping, + max_lora_rank, + top_k_num, + lora_ids, + num_active_loras, + adapter_enabled, + mul_routed_weight, + shrink_block_size_m, + ) + return + device = qcurr_hidden_states.device num_slices = len(lora_a_stacked) w1_lora_b_stacked = lora_b_stacked[0] From 977f1a01e1764835dab2959b520501730622975f Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Mon, 27 Apr 2026 11:00:42 +0000 Subject: [PATCH 02/10] Move Signed-off-by: Jee Jee Li --- vllm/lora/ops/triton_ops/fused_moe_lora_op.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/vllm/lora/ops/triton_ops/fused_moe_lora_op.py b/vllm/lora/ops/triton_ops/fused_moe_lora_op.py index 8793d6113b48..0f829c9fb8ee 100644 --- a/vllm/lora/ops/triton_ops/fused_moe_lora_op.py +++ b/vllm/lora/ops/triton_ops/fused_moe_lora_op.py @@ -343,6 +343,7 @@ def _run_fused_moe_lora_one_shot( ) BLOCK_R = max(triton.next_power_of_2(rank), 16) + num_experts = A0.shape[1] naive = sorted_token_ids is None if naive: EM_grid = topk_weights.numel() @@ -408,13 +409,22 @@ def _run_fused_moe_lora_one_shot( npid = min(npid_occ, max_npid_by_budget) npid = max(1, min(npid, max(1, N_per_slice // 128))) - # Robust defaults across the prefill regime (H100/H200, bf16/fp16). + # Robust defaults across the prefill regime (H100/H200/B200, bf16/fp16). # NPID > 1 is the small-M / under-saturated path -- more warps + a # deeper pipeline help amortise the inner-N expand loop. if npid > 1: - block_n, block_k, nw, ns = 128, 64, 8, 4 + block_n, nw, ns = 128, 8, 4 else: - block_n, block_k, nw, ns = 128, 64, 4, 3 + block_n, nw, ns = 128, 4, 3 + # BLOCK_K choice: when each expert sees enough tokens, the wider K tile + # halves the K-loop trip count and amortises load/MMA setup -- worth ~5 + # to 10% on GB200 for prefill-sized inputs. For sparsely-populated + # routings (e.g. M=16 mixtral, ~4 tokens/expert) the wider tile inflates + # per-program startup and most blocks early-exit anyway, so we keep the + # narrower tile. Threshold derived from GB200 sweeps over mixtral + # (E=8) and qwen3moe (E=64). + work_per_expert = topk_weights.numel() / max(num_experts, 1) + block_k = 128 if work_per_expert >= 16 else 64 grid = (M_blocks * npid, num_slices, grid_lora_dim) From 1f26ac78dc65e2d876a4962e2ac08b6a4444eb5d Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Fri, 1 May 2026 13:58:18 +0000 Subject: [PATCH 03/10] Support dual streams Signed-off-by: Jee Jee Li --- tests/lora/test_fused_moe_lora_kernel.py | 402 ++++++++++++++++++ vllm/envs.py | 2 + vllm/lora/layers/base_linear.py | 15 +- vllm/lora/layers/fused_moe.py | 68 +++ vllm/lora/ops/triton_ops/fused_moe_lora_op.py | 34 +- vllm/lora/punica_wrapper/punica_gpu.py | 6 + vllm/lora/punica_wrapper/punica_xpu.py | 2 + .../layers/fused_moe/fused_moe.py | 236 ++++++---- .../layers/fused_moe/lora_context.py | 10 + .../layers/fused_moe/lora_experts_mixin.py | 4 + .../layers/fused_moe/runner/moe_runner.py | 2 + 11 files changed, 696 insertions(+), 85 deletions(-) diff --git a/tests/lora/test_fused_moe_lora_kernel.py b/tests/lora/test_fused_moe_lora_kernel.py index 720db2f788fa..d08159a0c85b 100644 --- a/tests/lora/test_fused_moe_lora_kernel.py +++ b/tests/lora/test_fused_moe_lora_kernel.py @@ -141,6 +141,7 @@ def use_fused_moe_lora_kernel( block_size, fully_sharded=False, offset=0, + add_inputs=True, ): max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1) max_num_tokens_padded = round_up(max_num_tokens_padded, block_size) @@ -222,6 +223,7 @@ def use_fused_moe_lora_kernel( mul_routed_weight, fully_sharded=fully_sharded, offset=offset, + add_inputs=add_inputs, ) @@ -371,6 +373,7 @@ def use_fused_moe_lora_kernel_naive( block_size, fully_sharded=False, offset=0, + add_inputs=True, ): """ Test helper for naive_block_assignment path. @@ -435,6 +438,7 @@ def use_fused_moe_lora_kernel_naive( mul_routed_weight=mul_routed_weight, fully_sharded=fully_sharded, offset=offset, + add_inputs=add_inputs, ) @@ -988,6 +992,7 @@ def _call_one_shot( num_active_loras, adapter_enabled, block_size, + add_inputs=True, ): """Direct call into fused_moe_lora with one-shot-routed defaults.""" from vllm.lora.ops.triton_ops import fused_moe_lora as _op @@ -1024,6 +1029,7 @@ def _call_one_shot( False, False, 0, + add_inputs, ) @@ -1318,3 +1324,399 @@ def test_fused_moe_lora_kernel_rejects_bad_block_size_m(device): adapter_enabled, block_size, ) + + +@pytest.mark.parametrize("naive", [True, False]) +@pytest.mark.parametrize("num_slices", [1, 2]) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("device", DEVICES) +def test_fused_moe_lora_kernel_add_inputs_parity(naive, num_slices, dtype, device): + """Prerequisite for the dual-stream LoRA path: add_inputs=False must + produce the LoRA delta only (matching what add_inputs=True would have + added on top of a residual). Concretely: + + run(add_inputs=True, output=residual) == residual + + run(add_inputs=False, + output=zeros) + + Covers both the SORTED path and the NAIVE block-assignment path.""" + torch.set_default_device(device) + set_random_seed(0) + + if naive: + # Naive path is gated by num_tokens * top_k * 8 <= num_experts * max_loras. + num_tokens, top_k, E, max_loras, R, K, N = 4, 2, 64, 8, 16, 1024, 512 + else: + num_tokens, top_k, E, max_loras, R, K, N = 32, 2, 8, 4, 16, 1024, 512 + block_size = 16 + + ( + topk_ids, + topk_weights, + token_lora_mapping, + lora_ids, + lora_a_stacked, + lora_b_stacked, + hidden_states, + ) = _build_one_shot_inputs( + num_tokens, top_k, E, max_loras, R, K, N, num_slices, block_size, dtype + ) + + adapter_enabled = torch.ones(max_loras + 1, dtype=torch.int32) + num_active_loras = torch.tensor([max_loras + 1], dtype=torch.int32, device="cpu") + + if naive: + sorted_token_ids = None + expert_ids = topk_ids.reshape(-1).contiguous() + num_tokens_post_padded = None + else: + max_pad = topk_ids.numel() + E * (block_size - 1) + max_pad = round_up(max_pad, block_size) + max_blocks = CEILDIV(max_pad, block_size) + sorted_token_ids = torch.empty((max_loras * max_pad,), dtype=torch.int32) + expert_ids = torch.empty((max_loras * max_blocks,), dtype=torch.int32) + num_tokens_post_padded = torch.empty((max_loras,), dtype=torch.int32) + ops.moe_lora_align_block_size( + topk_ids, + token_lora_mapping, + E, + block_size, + max_loras, + max_pad, + max_blocks, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + adapter_enabled, + lora_ids, + ) + sorted_token_ids = sorted_token_ids.view(max_loras, -1) + expert_ids = expert_ids.view(max_loras, -1) + + # Use a non-trivial residual so the in-place add semantics are exercised + # (a zero residual would make add_inputs=True and add_inputs=False + # trivially equal even if the kernel ignored the flag). + residual = torch.randn((num_tokens, top_k, num_slices * N), dtype=dtype) * 0.1 + + # Path A: run with add_inputs=True on top of the residual. + out_inplace = residual.clone() + _call_one_shot( + out_inplace, + hidden_states, + lora_a_stacked, + lora_b_stacked, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + token_lora_mapping, + R, + top_k, + lora_ids, + num_active_loras, + adapter_enabled, + block_size, + add_inputs=True, + ) + + # Path B: run with add_inputs=False into a zero buffer to get the delta only. + out_delta = torch.zeros((num_tokens, top_k, num_slices * N), dtype=dtype) + _call_one_shot( + out_delta, + hidden_states, + lora_a_stacked, + lora_b_stacked, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + token_lora_mapping, + R, + top_k, + lora_ids, + num_active_loras, + adapter_enabled, + block_size, + add_inputs=False, + ) + + # The kernel writes only into rows touched by some (lora_id, expert_id) + # tile; rows that no program owns retain their pre-call value. So an + # "untouched residual + 0" row in path A must equal "0 + untouched 0" + # in path B. The equality below holds for both touched and untouched + # rows: residual + delta_only == residual_in_inplace_run. + torch.testing.assert_close(out_inplace, residual + out_delta, atol=1e-2, rtol=1e-2) + + +@pytest.mark.parametrize("device", DEVICES) +def test_fused_moe_lora_kernel_dual_stream_smoke(device): + """End-to-end smoke for the dual-stream pattern used by + FusedMoEWithLoRA: run the LoRA fast path on an aux CUDA stream with + add_inputs=False, then sum into a residual on the default stream + via a cuda.Event-synchronised .add_(). The result must match the + in-place add_inputs=True path on the default stream. + + This complements test_fused_moe_lora_kernel_add_inputs_parity by + exercising the actual stream-coordination machinery (event.record / + event.wait) that TritonExperts.apply uses; the parity test only + proves the math. + """ + torch.set_default_device(device) + set_random_seed(0) + + num_tokens, top_k, E, max_loras, R, K, N = 32, 2, 8, 4, 16, 1024, 512 + block_size, num_slices, dtype = 16, 2, torch.bfloat16 + + ( + topk_ids, + topk_weights, + token_lora_mapping, + lora_ids, + lora_a_stacked, + lora_b_stacked, + hidden_states, + ) = _build_one_shot_inputs( + num_tokens, top_k, E, max_loras, R, K, N, num_slices, block_size, dtype + ) + adapter_enabled = torch.ones(max_loras + 1, dtype=torch.int32) + num_active_loras = torch.tensor([max_loras + 1], dtype=torch.int32, device="cpu") + + max_pad = topk_ids.numel() + E * (block_size - 1) + max_pad = round_up(max_pad, block_size) + max_blocks = CEILDIV(max_pad, block_size) + sorted_token_ids = torch.empty((max_loras * max_pad,), dtype=torch.int32) + expert_ids = torch.empty((max_loras * max_blocks,), dtype=torch.int32) + num_tokens_post_padded = torch.empty((max_loras,), dtype=torch.int32) + ops.moe_lora_align_block_size( + topk_ids, + token_lora_mapping, + E, + block_size, + max_loras, + max_pad, + max_blocks, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + adapter_enabled, + lora_ids, + ) + sorted_token_ids = sorted_token_ids.view(max_loras, -1) + expert_ids = expert_ids.view(max_loras, -1) + + residual = torch.randn((num_tokens, top_k, num_slices * N), dtype=dtype) * 0.1 + + # Reference: single-stream in-place add_inputs=True. + out_ref = residual.clone() + _call_one_shot( + out_ref, + hidden_states, + lora_a_stacked, + lora_b_stacked, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + token_lora_mapping, + R, + top_k, + lora_ids, + num_active_loras, + adapter_enabled, + block_size, + add_inputs=True, + ) + + # Dual-stream pattern: LoRA delta on aux stream, then join + add_. + aux_stream = torch.cuda.Stream() + event0 = torch.cuda.Event() + event1 = torch.cuda.Event() + out_dual = residual.clone() + delta = torch.zeros_like(residual) + + event0.record() + with torch.cuda.stream(aux_stream): + event0.wait() + _call_one_shot( + delta, + hidden_states, + lora_a_stacked, + lora_b_stacked, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + token_lora_mapping, + R, + top_k, + lora_ids, + num_active_loras, + adapter_enabled, + block_size, + add_inputs=False, + ) + event1.record() + event1.wait() + out_dual.add_(delta) + + torch.testing.assert_close(out_ref, out_dual, atol=1e-2, rtol=1e-2) + + +@pytest.mark.parametrize("device", DEVICES) +def test_fused_moe_lora_kernel_dual_stream_cuda_graph(device): + """The MoE-LoRA dual-stream block (LoRA delta on aux stream + .add_() + on default stream, joined via cuda.Event) must be capturable into a + torch.cuda.CUDAGraph and replayable with the same numerical result + as eager execution. This is the prerequisite for the dual-stream + path to work under vLLM's decode-time CUDA-graph capture. + """ + torch.set_default_device(device) + set_random_seed(0) + + num_tokens, top_k, E, max_loras, R, K, N = 32, 2, 8, 4, 16, 1024, 512 + block_size, num_slices, dtype = 16, 2, torch.bfloat16 + + ( + topk_ids, + topk_weights, + token_lora_mapping, + lora_ids, + lora_a_stacked, + lora_b_stacked, + hidden_states, + ) = _build_one_shot_inputs( + num_tokens, top_k, E, max_loras, R, K, N, num_slices, block_size, dtype + ) + adapter_enabled = torch.ones(max_loras + 1, dtype=torch.int32) + num_active_loras = torch.tensor([max_loras + 1], dtype=torch.int32, device="cpu") + + max_pad = topk_ids.numel() + E * (block_size - 1) + max_pad = round_up(max_pad, block_size) + max_blocks = CEILDIV(max_pad, block_size) + sorted_token_ids = torch.empty((max_loras * max_pad,), dtype=torch.int32) + expert_ids = torch.empty((max_loras * max_blocks,), dtype=torch.int32) + num_tokens_post_padded = torch.empty((max_loras,), dtype=torch.int32) + ops.moe_lora_align_block_size( + topk_ids, + token_lora_mapping, + E, + block_size, + max_loras, + max_pad, + max_blocks, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + adapter_enabled, + lora_ids, + ) + sorted_token_ids = sorted_token_ids.view(max_loras, -1) + expert_ids = expert_ids.view(max_loras, -1) + + # Persistent stream + events (must outlive the graph; created in capture + # would be replayed against fresh objects each replay, defeating the + # point). + aux_stream = torch.cuda.Stream() + event0 = torch.cuda.Event() + event1 = torch.cuda.Event() + + residual = torch.randn((num_tokens, top_k, num_slices * N), dtype=dtype) * 0.1 + + def _run_dual_stream(out_buf: torch.Tensor) -> None: + # Mirrors the structure of TritonExperts.apply's w13 / w2 blocks. + delta = torch.zeros_like(out_buf) + event0.record() + with torch.cuda.stream(aux_stream): + event0.wait() + _call_one_shot( + delta, + hidden_states, + lora_a_stacked, + lora_b_stacked, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + token_lora_mapping, + R, + top_k, + lora_ids, + num_active_loras, + adapter_enabled, + block_size, + add_inputs=False, + ) + event1.record() + event1.wait() + out_buf.add_(delta) + + # Warm up: triton compile cache must be primed before capture, otherwise + # JIT compilation gets recorded into the graph (or fails capture). + warm = residual.clone() + _run_dual_stream(warm) + torch.cuda.synchronize() + + # Eager baseline + out_eager = residual.clone() + _run_dual_stream(out_eager) + torch.cuda.synchronize() + + # Captured + replay. Capture stream is separate from the default; events + # used inside must be recorded/waited on the capture stream (or aux + # stream that's dependent on it). torch.cuda.graph handles the stream + # accounting automatically. + g = torch.cuda.CUDAGraph() + out_graph = residual.clone() + capture_stream = torch.cuda.Stream() + capture_stream.wait_stream(torch.cuda.current_stream()) + with torch.cuda.stream(capture_stream): + with torch.cuda.graph(g, stream=capture_stream): + _run_dual_stream(out_graph) + torch.cuda.current_stream().wait_stream(capture_stream) + torch.cuda.synchronize() + + # First replay: the buffer was already populated by the capture itself, + # which is the eager-style write done during stream-recording. To + # validate replay semantics, reset the output and replay. + out_graph.copy_(residual) + g.replay() + torch.cuda.synchronize() + + torch.testing.assert_close(out_graph, out_eager, atol=1e-2, rtol=1e-2) + + # Replay a second time to confirm graph state is replayable repeatedly. + out_graph.copy_(residual) + g.replay() + torch.cuda.synchronize() + torch.testing.assert_close(out_graph, out_eager, atol=1e-2, rtol=1e-2) + + +def test_moe_forward_custom_op_registered(): + """The dual-stream MoE-LoRA path in TritonExperts.apply (see + vllm/model_executor/layers/fused_moe/fused_moe.py) relies on the entire + MoE forward being reachable only through `torch.ops.vllm.moe_forward` / + `torch.ops.vllm.moe_forward_shared`, both of which are opaque custom + ops. That opacity is what makes torch.compile / Dynamo stop *before* + seeing our `torch.cuda.stream(...)` / `event.record()/wait()` calls, + so we don't have to wrap the dual-stream block in its own custom op. + + If a future refactor drops these registrations (or reroutes the MoE + forward through a non-opaque path), the dual-stream code would start + triggering Dynamo graph breaks -- or, worse, fail silently under + torch.compile. This test exists to catch that regression at the + invariant level rather than via a flaky end-to-end compile run. + """ + # Importing the module side-effect-registers the ops. + import vllm.model_executor.layers.fused_moe.runner.moe_runner # noqa: F401 + + # Both ops must exist on torch.ops.vllm. + assert hasattr(torch.ops.vllm, "moe_forward"), ( + "torch.ops.vllm.moe_forward is gone. The MoE-LoRA dual-stream " + "path assumed this wrapper made the whole MoE forward opaque to " + "Dynamo. See the NOTE block above the registration in " + "vllm/model_executor/layers/fused_moe/runner/moe_runner.py." + ) + assert hasattr(torch.ops.vllm, "moe_forward_shared"), ( + "torch.ops.vllm.moe_forward_shared is gone. Same dual-stream " + "contract -- see runner/moe_runner.py NOTE." + ) diff --git a/vllm/envs.py b/vllm/envs.py index d55732cb6a8d..b6eff2fb3c31 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -1752,6 +1752,8 @@ def _get_or_set_default() -> str: int(os.getenv("VLLM_USE_SIMPLE_KV_OFFLOAD", "0")) ), # Whether to enable dual cuda streams for LoRA computation + # (used by both BaseLinearLayerWithLoRA and FusedMoEWithLoRA to + # overlap the base layer compute with the LoRA fast path). "VLLM_LORA_ENABLE_DUAL_STREAM": lambda: bool( int(os.getenv("VLLM_LORA_ENABLE_DUAL_STREAM", "0")) ), diff --git a/vllm/lora/layers/base_linear.py b/vllm/lora/layers/base_linear.py index 68783ae50d4b..f4662bfa8e64 100644 --- a/vllm/lora/layers/base_linear.py +++ b/vllm/lora/layers/base_linear.py @@ -100,7 +100,20 @@ def _init_lora_stream_context(self) -> None: self.layer_name = self.base_layer.prefix + ".lora_linear_async" compilation_config = vllm_config.compilation_config if self.layer_name in compilation_config.static_forward_context: - raise ValueError("Duplicate layer name: {}".format(self.layer_name)) + # TEMP(unblock-end-to-end): Upstream FusedMoE.runner exposes the + # gate via an aliased path (mlp.gate AND mlp.experts.runner.gate + # refer to the same nn.Module), so LoRA module replacement wraps + # the same gate twice. Both wrappers compute self.layer_name + # from base_layer.prefix and thus collide here. The two + # wrappers share a base_layer, so disabling dual-stream on the + # duplicate (it falls through to _apply_sync) is safe and + # preserves correctness; only the first-registered wrapper + # keeps the overlap. TODO(remove once upstream LoRA walker + # de-duplicates by base-layer identity). + self._enable_aux_cuda_stream = False + self._lora_stream = None + self._events = [] + return compilation_config.static_forward_context[self.layer_name] = self def create_lora_weights( diff --git a/vllm/lora/layers/fused_moe.py b/vllm/lora/layers/fused_moe.py index 284ac54997fb..ea3d1de31490 100644 --- a/vllm/lora/layers/fused_moe.py +++ b/vllm/lora/layers/fused_moe.py @@ -22,9 +22,23 @@ from vllm.model_executor.layers.fused_moe.prepare_finalize import ( MoEPrepareAndFinalizeNoDPEPModular, ) +from vllm.platforms import current_platform from .utils import _get_lora_device +# Process-wide singleton aux stream shared by every FusedMoEWithLoRA instance. +# Mirrors the pattern in vllm/lora/layers/base_linear.py: one extra stream is +# enough to overlap two compute streams; allocating one per layer would +# under-utilise the SMs and inflate context-switch cost. +_moe_lora_aux_cuda_stream: torch.cuda.Stream | None = None + + +def _get_moe_lora_aux_cuda_stream() -> torch.cuda.Stream | None: + global _moe_lora_aux_cuda_stream + if _moe_lora_aux_cuda_stream is None and current_platform.is_cuda_alike(): + _moe_lora_aux_cuda_stream = torch.cuda.Stream() + return _moe_lora_aux_cuda_stream + class FusedMoEWithLoRA(BaseLayerWithLoRA): def __init__(self, base_layer: FusedMoE) -> None: @@ -40,6 +54,13 @@ def __init__(self, base_layer: FusedMoE) -> None: self.tp_size = get_tensor_model_parallel_world_size() self.tp_rank = get_tensor_model_parallel_rank() self.device = _get_lora_device(base_layer) + # Reuses VLLM_LORA_ENABLE_DUAL_STREAM (the same env that controls + # the linear-LoRA dual-stream path in + # vllm/lora/layers/base_linear.py); enabling it for a deployment + # turns dual-stream on for both linear and MoE LoRA layers in one + # switch. + self._enable_aux_cuda_stream = envs.VLLM_LORA_ENABLE_DUAL_STREAM + self._init_lora_stream_context() # For non-gated MoE (is_act_and_mul=False), only 1 slice is needed # since there's only up_proj (w1), not gate_proj + up_proj (w1 + w3) self._w13_slices = 2 if base_layer.moe_config.is_act_and_mul else 1 @@ -70,7 +91,35 @@ def __init__(self, base_layer: FusedMoE) -> None: FusedMoEModularMethod(self.base_layer.quant_method, moe_kernel) ) + def _init_lora_stream_context(self) -> None: + # Dual-stream is incompatible with fully_sharded MoE LoRA: that path + # routes through the legacy two-kernel flow with an embedded + # all_reduce / all_gather between shrink and expand, where the + # add_inputs=False contract is not wired (see _fused_moe_lora). + # When fully_sharded is enabled at LoRA-config time, we silently + # disable the dual-stream path here so the env var still works. + self._lora_stream: torch.cuda.Stream | None = None + self._events: tuple[torch.cuda.Event, ...] | None = None + if not self._enable_aux_cuda_stream: + return + if not current_platform.is_cuda_alike(): + return + self._lora_stream = _get_moe_lora_aux_cuda_stream() + # 4 events: 2 per (base GEMM, LoRA) pair so w13 and w2 don't reuse + # the same event objects; reuse-within-a-pair is fine because the + # second pair starts only after intermediate_cache1.add_() has joined. + self._events = tuple(torch.cuda.Event() for _ in range(4)) + def _build_lora_context(self): + # Hand the stream/events to the experts only when fully_sharded is + # off (the path the one-shot kernel + add_inputs=False contract + # supports). For fully_sharded we leave aux_stream=None so + # experts.apply() takes the original sequential schedule. + use_dual_stream = ( + self._enable_aux_cuda_stream + and not self.fully_sharded + and self._lora_stream is not None + ) return MoELoRAContext( w13_lora_a_stacked=self.w13_lora_a_stacked, w13_lora_b_stacked=self.w13_lora_b_stacked, @@ -86,6 +135,8 @@ def _build_lora_context(self): local_num_experts=self.base_layer.local_num_experts, punica_wrapper=self.punica_wrapper, use_tuned_config=bool(envs.VLLM_TUNED_CONFIG_FOLDER), + aux_stream=self._lora_stream if use_dual_stream else None, + events=self._events if use_dual_stream else None, ) def _create_lora_a_weights( @@ -339,6 +390,23 @@ def quant_method(self): def is_internal_router(self) -> bool: return self.base_layer.is_internal_router + # TEMP(unblock-end-to-end): Upstream introduced FusedMoE.runner (see + # vllm/model_executor/layers/fused_moe/runner/), which holds gates and + # other linear submodules that LoRA wants to wrap. The LoRA module + # walker (vllm/lora/model_manager.py::_create_lora_modules) iterates + # the model BEFORE wrapping, then calls + # nn.Module.get_submodule("...experts.runner.X") — which uses + # hasattr/getattr — to attach LoRA wrappers. Once "experts" has been + # replaced by FusedMoEWithLoRA, that lookup fails because runner is + # only on base_layer. Forwarding via @property is enough to satisfy + # get_submodule without registering runner as our own child module + # (which would double-count parameters / state_dict entries). + # TODO(remove once upstream LoRA walker handles wrapped modules + # natively). + @property + def runner(self): + return self.base_layer.runner + @classmethod def can_replace_layer( cls, diff --git a/vllm/lora/ops/triton_ops/fused_moe_lora_op.py b/vllm/lora/ops/triton_ops/fused_moe_lora_op.py index 0f829c9fb8ee..61303360363e 100644 --- a/vllm/lora/ops/triton_ops/fused_moe_lora_op.py +++ b/vllm/lora/ops/triton_ops/fused_moe_lora_op.py @@ -164,6 +164,7 @@ def _fused_moe_lora_one_shot_kernel( BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, EVEN_K: tl.constexpr, + ADD_INPUTS: tl.constexpr, ): pid_full = tl.program_id(axis=0) pid_m = pid_full // NPID_FACTOR @@ -294,8 +295,11 @@ def _fused_moe_lora_one_shot_kernel( + offs_n[None, :] * stride_on ) out_mask = token_mask[:, None] & n_mask[None, :] - prev = tl.load(out_ptrs, mask=out_mask, other=0.0) - tl.store(out_ptrs, prev + acc.to(out_ptr.dtype.element_ty), mask=out_mask) + if ADD_INPUTS: + prev = tl.load(out_ptrs, mask=out_mask, other=0.0) + tl.store(out_ptrs, prev + acc.to(out_ptr.dtype.element_ty), mask=out_mask) + else: + tl.store(out_ptrs, acc.to(out_ptr.dtype.element_ty), mask=out_mask) def _run_fused_moe_lora_one_shot( @@ -315,14 +319,16 @@ def _run_fused_moe_lora_one_shot( adapter_enabled: torch.Tensor, mul_routed_weight: bool, block_size_m: int, + add_inputs: bool = True, ) -> None: """Fast-path wrapper: launches one fused shrink+expand kernel. - The shape contract matches `_fused_moe_lora`. `output` is expected to - have shape `(num_tokens, top_k_num, num_slices * N_per_slice)` and is - accumulated in-place. Only used when `fully_sharded=False` and the - caller's `offset` is 0 (which it always is in that case -- see - vllm/lora/layers/fused_moe.py). + The shape contract matches `_fused_moe_lora`. `output` has shape + `(num_tokens, top_k_num, num_slices * N_per_slice)`. When + `add_inputs=True` (default) the kernel reads-modifies-writes `output` + in place; when `add_inputs=False` the kernel overwrites `output` with + the LoRA delta only. The latter is used by the dual-stream path that + sums LoRA into the base output on a separate stream. """ num_slices = len(lora_a_stacked) device = qcurr_hidden_states.device @@ -469,6 +475,7 @@ def _run_fused_moe_lora_one_shot( NPID_FACTOR=npid, BLOCK_N=block_n, BLOCK_K=block_k, + ADD_INPUTS=add_inputs, num_warps=nw, num_stages=ns, ) @@ -1075,6 +1082,7 @@ def _fused_moe_lora( mul_routed_weight: bool = False, fully_sharded: bool = False, offset: int = 0, + add_inputs: bool = True, ) -> None: assert len(lora_a_stacked) == len(lora_b_stacked) > 0 assert topk_weights.dim() == qcurr_hidden_states.dim() == 2 @@ -1125,9 +1133,20 @@ def _fused_moe_lora( adapter_enabled, mul_routed_weight, shrink_block_size_m, + add_inputs=add_inputs, ) return + # The legacy two-kernel path keeps the historical in-place semantics -- + # `_fused_moe_lora_expand` always sets `ADD_INPUTS=True` so the rank-dim + # cache flowing through all_reduce/all_gather stays consistent. The + # add_inputs=False contract is only wired for the one-shot fast path + # above, so reject it here rather than silently writing wrong results. + assert add_inputs, ( + "fused_moe_lora(add_inputs=False) is only supported on the " + "fully_sharded=False fast path" + ) + device = qcurr_hidden_states.device num_slices = len(lora_a_stacked) w1_lora_b_stacked = lora_b_stacked[0] @@ -1294,6 +1313,7 @@ def _fused_moe_lora_fake( mul_routed_weight: bool = False, fully_sharded: bool = False, offset: int = 0, + add_inputs: bool = True, ) -> None: return diff --git a/vllm/lora/punica_wrapper/punica_gpu.py b/vllm/lora/punica_wrapper/punica_gpu.py index 44d1dbd50728..c9dee2afd797 100644 --- a/vllm/lora/punica_wrapper/punica_gpu.py +++ b/vllm/lora/punica_wrapper/punica_gpu.py @@ -409,6 +409,7 @@ def add_lora_fused_moe( fully_sharded: bool = False, offset: int = 0, token_lora_mapping: torch.Tensor | None = None, + add_inputs: bool = True, ): """ Performs a fused forward computation for LoRA of Mixture-of-Experts (MoE) layer. @@ -458,6 +459,7 @@ def add_lora_fused_moe( mul_routed_weight, fully_sharded, offset, + add_inputs, ) def add_lora_w13( @@ -480,6 +482,7 @@ def add_lora_w13( num_slices: int, fully_sharded: bool, use_tuned_config: bool, + add_inputs: bool = True, ) -> tuple[ torch.Tensor | None, torch.Tensor | None, @@ -582,6 +585,7 @@ def add_lora_w13( adapter_enabled, fully_sharded=fully_sharded, token_lora_mapping=token_lora_mapping, + add_inputs=add_inputs, ) return ( @@ -612,6 +616,7 @@ def add_lora_w2( fully_sharded: bool, tp_rank: int, use_tuned_config: bool, + add_inputs: bool = True, ) -> None: import functools @@ -694,4 +699,5 @@ def add_lora_w2( fully_sharded=fully_sharded, offset=offset, token_lora_mapping=token_lora_mapping, + add_inputs=add_inputs, ) diff --git a/vllm/lora/punica_wrapper/punica_xpu.py b/vllm/lora/punica_wrapper/punica_xpu.py index f031e1bfa341..8ad6fec5011e 100644 --- a/vllm/lora/punica_wrapper/punica_xpu.py +++ b/vllm/lora/punica_wrapper/punica_xpu.py @@ -369,6 +369,7 @@ def add_lora_fused_moe( fully_sharded: bool = False, offset: int = 0, token_lora_mapping: torch.Tensor | None = None, + add_inputs: bool = True, ): """ Performs a fused forward computation for LoRA of Mixture-of-Experts (MoE) layer. @@ -418,4 +419,5 @@ def add_lora_fused_moe( mul_routed_weight, fully_sharded, offset, + add_inputs, ) diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 7e7bcc709921..5b4a8d8f0945 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -50,6 +50,7 @@ ) from vllm.platforms import current_platform from vllm.triton_utils import tl, triton +from vllm.utils.multi_stream_utils import maybe_execute_in_parallel from vllm.utils.torch_utils import direct_register_custom_op logger = init_logger(__name__) @@ -2072,55 +2073,102 @@ def apply( ) ) - invoke_fused_moe_triton_kernel( - hidden_states, - w1, - intermediate_cache1, - a1q_scale, - self.w1_scale, - None, # topk_weights - sorted_token_ids, - expert_ids, - num_tokens_post_padded, - False, # mul_routed_weights - top_k_num, - config, - compute_type=compute_type, - use_fp8_w8a8=self.quant_config.use_fp8_w8a8, - use_int8_w8a8=self.quant_config.use_int8_w8a8, - use_int8_w8a16=self.quant_config.use_int8_w8a16, - use_int4_w4a16=self.quant_config.use_int4_w4a16, - per_channel_quant=self.per_act_token_quant, - block_shape=self.block_shape, - B_bias=self.w1_bias, - ) - - # LoRA w13: applied to intermediate_cache1 before activation, using - # hidden_states as the lora_a input. moe_lora_align_block_size is - # called once here and results reused for the w2 LoRA below. + # LoRA w13: applied to intermediate_cache1 before activation. When + # the LoRA layer requested a dual-stream schedule, we run base w13 + # GEMM on the default stream and the LoRA fast-path on aux_stream; + # the LoRA writes its delta into a fresh zero buffer (add_inputs= + # False) and we sum it into intermediate_cache1 after both finish. + # + # Note on torch.compile : + # The whole MoE forward is already wrapped in torch.ops.vllm.moe_forward`, + # so we don't need to wrapp the following code as custom op sorted_token_ids_lora = None expert_ids_lora = None num_tokens_post_padded_lora = None token_lora_mapping = None lora_context = self._lora_context - if lora_context is not None: + + def _base_w13_fn(): + invoke_fused_moe_triton_kernel( + hidden_states, + w1, + intermediate_cache1, + a1q_scale, + self.w1_scale, + None, # topk_weights + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + False, # mul_routed_weights + top_k_num, + config, + compute_type=compute_type, + use_fp8_w8a8=self.quant_config.use_fp8_w8a8, + use_int8_w8a8=self.quant_config.use_int8_w8a8, + use_int8_w8a16=self.quant_config.use_int8_w8a16, + use_int4_w4a16=self.quant_config.use_int4_w4a16, + per_channel_quant=self.per_act_token_quant, + block_shape=self.block_shape, + B_bias=self.w1_bias, + ) + + if lora_context is not None and lora_context.aux_stream is not None: + # add_inputs=False: kernel overwrites lora_delta_w13. zeros (not + # empty) so untouched rows -- e.g. blocks where every program + # early-exits because lora_id<0 -- stay at zero and the trailing + # add_() is a no-op there. + lora_delta_w13 = torch.zeros_like(intermediate_cache1) + + def _lora_w13_fn(): + return self.apply_w13_lora( + lora_context, + y=lora_delta_w13, + x=hidden_states, + topk_ids=topk_ids, + topk_weights=topk_weights, + expert_map=expert_map, + w1=w1, + w2=w2, + num_tokens=num_tokens, + top_k_num=top_k_num, + add_inputs=False, + ) + + assert lora_context.events is not None + _, lora_meta = maybe_execute_in_parallel( + _base_w13_fn, + _lora_w13_fn, + lora_context.events[0], + lora_context.events[1], + lora_context.aux_stream, + ) ( sorted_token_ids_lora, expert_ids_lora, num_tokens_post_padded_lora, token_lora_mapping, - ) = self.apply_w13_lora( - lora_context, - y=intermediate_cache1, - x=hidden_states, - topk_ids=topk_ids, - topk_weights=topk_weights, - expert_map=expert_map, - w1=w1, - w2=w2, - num_tokens=num_tokens, - top_k_num=top_k_num, - ) + ) = lora_meta + intermediate_cache1.add_(lora_delta_w13) + else: + _base_w13_fn() + if lora_context is not None: + ( + sorted_token_ids_lora, + expert_ids_lora, + num_tokens_post_padded_lora, + token_lora_mapping, + ) = self.apply_w13_lora( + lora_context, + y=intermediate_cache1, + x=hidden_states, + topk_ids=topk_ids, + topk_weights=topk_weights, + expert_map=expert_map, + w1=w1, + w2=w2, + num_tokens=num_tokens, + top_k_num=top_k_num, + ) self.activation( activation, intermediate_cache2, intermediate_cache1.view(-1, N) @@ -2137,47 +2185,81 @@ def apply( quantization_emulation=self.quantization_emulation, ) - invoke_fused_moe_triton_kernel( - qintermediate_cache2, - w2, - intermediate_cache3, - a2q_scale, - self.w2_scale, - topk_weights, - sorted_token_ids, - expert_ids, - num_tokens_post_padded, - not apply_router_weight_on_input, - 1, - config, - compute_type=compute_type, - use_fp8_w8a8=self.quant_config.use_fp8_w8a8, - use_int8_w8a8=self.quant_config.use_int8_w8a8, - use_int8_w8a16=self.quant_config.use_int8_w8a16, - use_int4_w4a16=self.quant_config.use_int4_w4a16, - per_channel_quant=self.per_act_token_quant, - block_shape=self.block_shape, - B_bias=self.w2_bias, - ) - # LoRA w2: applied to intermediate_cache3 before moe_sum, using the # unquantized intermediate_cache2 as the lora_a input. Reuses the - # sorted_token_ids_lora computed above. - if lora_context is not None: - self.apply_w2_lora( - lora_context, - y=intermediate_cache3, - x=intermediate_cache2, - topk_weights=topk_weights, - sorted_token_ids_lora=sorted_token_ids_lora, - expert_ids_lora=expert_ids_lora, - num_tokens_post_padded_lora=num_tokens_post_padded_lora, - token_lora_mapping=token_lora_mapping, - num_tokens=num_tokens, - w1=w1, - w2=w2, - top_k_num=top_k_num, + # sorted_token_ids_lora computed above. Same dual-stream pattern as + # the w13 pair: base GEMM on default stream, LoRA delta on aux, + # join via .add_() into intermediate_cache3. + def _base_w2_fn(): + invoke_fused_moe_triton_kernel( + qintermediate_cache2, + w2, + intermediate_cache3, + a2q_scale, + self.w2_scale, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + not apply_router_weight_on_input, + 1, + config, + compute_type=compute_type, + use_fp8_w8a8=self.quant_config.use_fp8_w8a8, + use_int8_w8a8=self.quant_config.use_int8_w8a8, + use_int8_w8a16=self.quant_config.use_int8_w8a16, + use_int4_w4a16=self.quant_config.use_int4_w4a16, + per_channel_quant=self.per_act_token_quant, + block_shape=self.block_shape, + B_bias=self.w2_bias, + ) + + if lora_context is not None and lora_context.aux_stream is not None: + lora_delta_w2 = torch.zeros_like(intermediate_cache3) + + def _lora_w2_fn(): + self.apply_w2_lora( + lora_context, + y=lora_delta_w2, + x=intermediate_cache2, + topk_weights=topk_weights, + sorted_token_ids_lora=sorted_token_ids_lora, + expert_ids_lora=expert_ids_lora, + num_tokens_post_padded_lora=num_tokens_post_padded_lora, + token_lora_mapping=token_lora_mapping, + num_tokens=num_tokens, + w1=w1, + w2=w2, + top_k_num=top_k_num, + add_inputs=False, + ) + + assert lora_context.events is not None + maybe_execute_in_parallel( + _base_w2_fn, + _lora_w2_fn, + lora_context.events[2], + lora_context.events[3], + lora_context.aux_stream, ) + intermediate_cache3.add_(lora_delta_w2) + else: + _base_w2_fn() + if lora_context is not None: + self.apply_w2_lora( + lora_context, + y=intermediate_cache3, + x=intermediate_cache2, + topk_weights=topk_weights, + sorted_token_ids_lora=sorted_token_ids_lora, + expert_ids_lora=expert_ids_lora, + num_tokens_post_padded_lora=num_tokens_post_padded_lora, + token_lora_mapping=token_lora_mapping, + num_tokens=num_tokens, + w1=w1, + w2=w2, + top_k_num=top_k_num, + ) # separate function is required for MoE + LoRA self.moe_sum(intermediate_cache3, output) diff --git a/vllm/model_executor/layers/fused_moe/lora_context.py b/vllm/model_executor/layers/fused_moe/lora_context.py index 92500a7bb47d..af26680585ef 100644 --- a/vllm/model_executor/layers/fused_moe/lora_context.py +++ b/vllm/model_executor/layers/fused_moe/lora_context.py @@ -42,3 +42,13 @@ class MoELoRAContext: # Whether VLLM_TUNED_CONFIG_FOLDER is set; selects get_lora_op_configs vs # try_get_optimal_moe_lora_config for Triton kernel tile configs. use_tuned_config: bool + + # Optional dual-stream support for overlapping each (base GEMM, LoRA) + # pair. When aux_stream is None, the experts.apply() path runs the + # original sequential schedule. When set, base GEMM runs on the default + # stream and the LoRA fast-path writes the delta into a fresh buffer on + # aux_stream, which the default stream sums in afterwards. + # Events are paired one-per-overlap-pair: events[0,1] for w13, + # events[2,3] for w2, so the two pairs do not race on the same event. + aux_stream: torch.cuda.Stream | None = None + events: tuple[torch.cuda.Event, ...] | None = None diff --git a/vllm/model_executor/layers/fused_moe/lora_experts_mixin.py b/vllm/model_executor/layers/fused_moe/lora_experts_mixin.py index c609c5cf56b5..18fe4a584e29 100644 --- a/vllm/model_executor/layers/fused_moe/lora_experts_mixin.py +++ b/vllm/model_executor/layers/fused_moe/lora_experts_mixin.py @@ -45,6 +45,7 @@ def apply_w13_lora( w2: torch.Tensor, num_tokens: int, top_k_num: int, + add_inputs: bool = True, ) -> tuple[ torch.Tensor | None, torch.Tensor | None, @@ -70,6 +71,7 @@ def apply_w13_lora( lora_context.w13_num_slices, lora_context.fully_sharded, lora_context.use_tuned_config, + add_inputs=add_inputs, ) def apply_w2_lora( @@ -87,6 +89,7 @@ def apply_w2_lora( w1: torch.Tensor, w2: torch.Tensor, top_k_num: int, + add_inputs: bool = True, ) -> None: lora_context.punica_wrapper.add_lora_w2( y, @@ -108,4 +111,5 @@ def apply_w2_lora( lora_context.fully_sharded, lora_context.tp_rank, lora_context.use_tuned_config, + add_inputs=add_inputs, ) diff --git a/vllm/model_executor/layers/fused_moe/runner/moe_runner.py b/vllm/model_executor/layers/fused_moe/runner/moe_runner.py index 2eee8acf6b8f..e55b4cfd2144 100644 --- a/vllm/model_executor/layers/fused_moe/runner/moe_runner.py +++ b/vllm/model_executor/layers/fused_moe/runner/moe_runner.py @@ -151,6 +151,8 @@ def _moe_forward_shared_fake( return shared_out, fused_out +# NOTE: `moe_forward` and `moe_forward_shared` being opaque custom ops is a +# load-bearing assumption for the MoE-LoRA dual-stream path. direct_register_custom_op( op_name="moe_forward", op_func=_moe_forward, From ce0e0c864b4563cd8b5839d99c403b6adab1aedc Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Fri, 8 May 2026 02:56:52 +0000 Subject: [PATCH 04/10] Add glm config Signed-off-by: Jee Jee Li --- .../benchmark_fused_moe_lora_one_shot.py | 45 +++++++++++++++---- 1 file changed, 37 insertions(+), 8 deletions(-) diff --git a/benchmarks/kernels/benchmark_fused_moe_lora_one_shot.py b/benchmarks/kernels/benchmark_fused_moe_lora_one_shot.py index 6193bb9fa029..c5a592ea7e2b 100644 --- a/benchmarks/kernels/benchmark_fused_moe_lora_one_shot.py +++ b/benchmarks/kernels/benchmark_fused_moe_lora_one_shot.py @@ -19,6 +19,7 @@ from __future__ import annotations import argparse +import os import random import torch @@ -325,6 +326,17 @@ def _run_two_kernel(inp: dict): num_slices=2, block_size_m=64, ), + # GLM-5.1 (zai-org/GLM-5.1-FP8): E=256, top_k=8, hidden=6144, + # moe_intermediate=2048 + "glm5_1": dict( + K=6144, + N_per_slice=2048, + num_experts=256, + top_k=8, + max_loras=4, + num_slices=2, + block_size_m=64, + ), } @@ -332,8 +344,10 @@ def _run_two_kernel(inp: dict): RANK_RANGE = [8, 16, 32, 64] -def get_benchmark(model: str): - preset = MODEL_PRESETS[model] +def get_benchmark(model: str, max_loras: int | None = None): + preset = dict(MODEL_PRESETS[model]) + if max_loras is not None: + preset["max_loras"] = max_loras @triton.testing.perf_report( triton.testing.Benchmark( @@ -344,7 +358,7 @@ def get_benchmark(model: str): line_names=["one_shot (fused)", "two_kernel (legacy)"], styles=[("red", "-"), ("blue", "-")], ylabel="ms", - plot_name=f"fused_moe_lora-{model}", + plot_name=f"fused_moe_lora-{model}-loras{preset['max_loras']}", args={"preset": preset}, ) ) @@ -373,8 +387,10 @@ def benchmark(M, rank, provider, preset): # ----- correctness sanity --------------------------------------------------- -def calculate_diff(model: str, M: int, rank: int): - preset = MODEL_PRESETS[model] +def calculate_diff(model: str, M: int, rank: int, max_loras: int | None = None): + preset = dict(MODEL_PRESETS[model]) + if max_loras is not None: + preset["max_loras"] = max_loras inp = _make_inputs( M=M, K=preset["K"], @@ -423,14 +439,27 @@ def calculate_diff(model: str, M: int, rank: int): action="store_true", help="Run correctness sanity check only, no perf sweep", ) + parser.add_argument( + "--max-loras", + type=int, + default=None, + help="Override max_loras in the model preset (number of LoRA adapters " + "active in the batch). Defaults to the preset's value.", + ) args = parser.parse_args() print(f"Correctness check ({args.model}):") - calculate_diff(args.model, M=256, rank=32) + calculate_diff(args.model, M=256, rank=32, max_loras=args.max_loras) if args.check_only: raise SystemExit(0) + effective_max_loras = ( + args.max_loras + if args.max_loras is not None + else MODEL_PRESETS[args.model]["max_loras"] + ) print(f"\nGPU: {torch.cuda.get_device_name()}") - print(f"Model preset: {args.model}\n") - benchmark = get_benchmark(args.model) + print(f"Model preset: {args.model} max_loras={effective_max_loras}\n") + benchmark = get_benchmark(args.model, max_loras=args.max_loras) + os.makedirs(args.save_path, exist_ok=True) benchmark.run(print_data=True, save_path=args.save_path) From c787eb3b0a3132588f5d6b1d532f761f503d2054 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Mon, 11 May 2026 07:15:44 +0000 Subject: [PATCH 05/10] Support samll batch fused kernel Signed-off-by: Jee Jee Li --- tests/lora/test_fused_moe_lora_kernel.py | 5 +- vllm/envs.py | 8 + vllm/lora/layers/fused_moe.py | 8 +- vllm/lora/ops/triton_ops/fused_moe_lora_op.py | 383 +++++++++++++++++- .../layers/fused_moe/fused_moe.py | 2 +- 5 files changed, 399 insertions(+), 7 deletions(-) diff --git a/tests/lora/test_fused_moe_lora_kernel.py b/tests/lora/test_fused_moe_lora_kernel.py index d08159a0c85b..dfbb07b2dde9 100644 --- a/tests/lora/test_fused_moe_lora_kernel.py +++ b/tests/lora/test_fused_moe_lora_kernel.py @@ -1669,9 +1669,8 @@ def _run_dual_stream(out_buf: torch.Tensor) -> None: out_graph = residual.clone() capture_stream = torch.cuda.Stream() capture_stream.wait_stream(torch.cuda.current_stream()) - with torch.cuda.stream(capture_stream): - with torch.cuda.graph(g, stream=capture_stream): - _run_dual_stream(out_graph) + with torch.cuda.stream(capture_stream), torch.cuda.graph(g, stream=capture_stream): + _run_dual_stream(out_graph) torch.cuda.current_stream().wait_stream(capture_stream) torch.cuda.synchronize() diff --git a/vllm/envs.py b/vllm/envs.py index 77c10918fd52..c512338c8e70 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -268,6 +268,7 @@ VLLM_XPU_ENABLE_XPU_GRAPH: bool = False VLLM_XPU_USE_SAMPLER_KERNEL: bool = True VLLM_LORA_ENABLE_DUAL_STREAM: bool = False + VLLM_LORA_USE_ONE_SHOT_MOE: bool = True def get_default_cache_root(): @@ -1790,6 +1791,13 @@ def _get_or_set_default() -> str: "VLLM_LORA_ENABLE_DUAL_STREAM": lambda: bool( int(os.getenv("VLLM_LORA_ENABLE_DUAL_STREAM", "0")) ), + # Whether to use the one-shot fused MoE LoRA kernel (combined shrink+expand). + # When disabled, falls back to the legacy two-kernel shrink/expand path. + # Dual-stream MoE LoRA depends on the one-shot kernel's add_inputs=False + # contract, so dual-stream is force-disabled when this is off. + "VLLM_LORA_USE_ONE_SHOT_MOE": lambda: bool( + int(os.getenv("VLLM_LORA_USE_ONE_SHOT_MOE", "1")) + ), } diff --git a/vllm/lora/layers/fused_moe.py b/vllm/lora/layers/fused_moe.py index ea3d1de31490..3d8da2741ad1 100644 --- a/vllm/lora/layers/fused_moe.py +++ b/vllm/lora/layers/fused_moe.py @@ -58,8 +58,12 @@ def __init__(self, base_layer: FusedMoE) -> None: # the linear-LoRA dual-stream path in # vllm/lora/layers/base_linear.py); enabling it for a deployment # turns dual-stream on for both linear and MoE LoRA layers in one - # switch. - self._enable_aux_cuda_stream = envs.VLLM_LORA_ENABLE_DUAL_STREAM + # switch. Dual-stream relies on the one-shot kernel's + # add_inputs=False contract, so it is force-disabled when the + # one-shot path is turned off via VLLM_LORA_USE_ONE_SHOT_MOE=0. + self._enable_aux_cuda_stream = ( + envs.VLLM_LORA_ENABLE_DUAL_STREAM and envs.VLLM_LORA_USE_ONE_SHOT_MOE + ) self._init_lora_stream_context() # For non-gated MoE (is_act_and_mul=False), only 1 slice is needed # since there's only up_proj (w1), not gate_proj + up_proj (w1 + w3) diff --git a/vllm/lora/ops/triton_ops/fused_moe_lora_op.py b/vllm/lora/ops/triton_ops/fused_moe_lora_op.py index 61303360363e..58f04aeae947 100644 --- a/vllm/lora/ops/triton_ops/fused_moe_lora_op.py +++ b/vllm/lora/ops/triton_ops/fused_moe_lora_op.py @@ -3,6 +3,7 @@ import torch +from vllm import envs from vllm.distributed import ( tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce, @@ -481,6 +482,352 @@ def _run_fused_moe_lora_one_shot( ) +# --------------------------------------------------------------------------- +# Small-batch (decode-style) fused MoE-LoRA kernel — sub-path of the +# one_shot fast path. +# --------------------------------------------------------------------------- + + +@triton.heuristics({"EVEN_K": lambda args: args["K"] % args["BLOCK_K"] == 0}) +@triton.jit +def _fused_moe_lora_small_batch_kernel( + # ---- pointers ---- + x_ptr, + A_ptrs, + B_ptrs, + out_ptr, + topk_weights_ptr, + expert_ids_ptr, # (num_tokens * top_k_num,) + token_lora_mapping_ptr, # (num_tokens,) + adapter_enabled_ptr, + # ---- dims ---- + N, + K, + top_k_num, + max_loras, + work_total, # = pair_slices * n_chunks_per_pair_slice + pair_slices, # = num_tokens * top_k_num * NUM_SLICES + # ---- strides ---- + stride_xm, + stride_xk, + stride_A_lora, + stride_A_expert, + stride_A_r, + stride_A_k, + stride_B_lora, + stride_B_expert, + stride_B_n, + stride_B_r, + stride_om, + stride_on, + # ---- scalar (runtime ints, NOT constexpr) ---- + # n_tiles_per_program / n_chunks_per_pair_slice are deliberately + # runtime: each distinct value would otherwise trigger a fresh Triton + # compile -> fresh kernel binary -> fresh CUDA graph instance per + # batch size. Production traces showed that variant explosion adding + # ~5.9k graph instantiations on top of legacy. Runtime args mean one + # shared binary across all chunk sizes. + slice_n_offset, + n_tiles_per_program, + n_chunks_per_pair_slice, + # ---- constexpr ---- + token_mapping_factor: tl.constexpr, + MUL_ROUTED_WEIGHT: tl.constexpr, + ADD_INPUTS: tl.constexpr, + BLOCK_R: tl.constexpr, + actual_rank: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + NUM_SLICES: tl.constexpr, + EVEN_K: tl.constexpr, +): + """Persistent fused MoE-LoRA kernel for naive_block_assignment inputs. + + Each program owns one (pair × slice × n_chunk) work item. A "chunk" + covers `n_tiles_per_program` consecutive output-N tiles, all of which + share a single shrink — so the rank-vector is computed once per + program and the A weights for that (lora, expert, slice) are loaded + once instead of n_tiles_per_program times. + + The wrapper picks `n_tiles_per_program` to keep the grid close to + 2*SM_count: at very small batch (work_total ≤ SM_count) the chunk + size collapses to 1 and behaviour matches a per-tile GEMV; as batch + grows the chunk grows so we trade some N-axis parallelism for shrink + reuse. When `work_total` exceeds the launched grid, the outer stride + loop drains the leftover work units serially. + """ + pid = tl.program_id(axis=0) + num_programs = tl.num_programs(axis=0) + + offs_r = tl.arange(0, BLOCK_R) + rank_mask = offs_r < actual_rank + # Clamp OOB rank lanes so they address row 0 of A/B; the mask zeros + # the loaded values. Required when BLOCK_R > actual_rank (e.g. rank=4 + # padded to 16) -- without clamping, tl.load would address the next + # expert's memory. + safe_offs_r = tl.where(rank_mask, offs_r, 0) + offs_k = tl.arange(0, BLOCK_K) + + # Persistent stride loop: when grid < work_total each program walks + # multiple work items. When grid == work_total the loop runs exactly + # once and the kernel degenerates to the per-tile GEMV. + for work_id in range(pid, work_total, num_programs): + n_chunk_idx = work_id % n_chunks_per_pair_slice + pair_slice_idx = work_id // n_chunks_per_pair_slice + # NUM_SLICES is constexpr (typ. 1 or 2) so divmod folds. + pair_idx = pair_slice_idx // NUM_SLICES + slice_id = pair_slice_idx % NUM_SLICES + + # Resolve lora_id / expert_id; skip the body for inactive lanes. + # Using a single `valid` flag instead of early `return` keeps the + # outer stride loop alive — `return` would exit the whole program + # and skip later work items assigned to this SM. + token_idx = pair_idx // top_k_num + lora_id = tl.load(token_lora_mapping_ptr + token_idx) + valid = (lora_id >= 0) & (lora_id < max_loras) + enabled = tl.load(adapter_enabled_ptr + tl.where(valid, lora_id, 0)) + valid = valid & (enabled != 0) + expert_id = tl.load(expert_ids_ptr + pair_idx) + valid = valid & (expert_id >= 0) + + if valid: + cur_A_ptr = tl.load(A_ptrs + slice_id).to( + tl.pointer_type(out_ptr.dtype.element_ty) + ) + cur_B_ptr = tl.load(B_ptrs + slice_id).to( + tl.pointer_type(out_ptr.dtype.element_ty) + ) + A_base = cur_A_ptr + lora_id * stride_A_lora + expert_id * stride_A_expert + B_base = cur_B_ptr + lora_id * stride_B_lora + expert_id * stride_B_expert + + x_row = pair_idx // token_mapping_factor + x_row_ptr = x_ptr + x_row * stride_xm + + # SHRINK GEMV (once per program; reused across n_tiles_per_program + # expand tiles below). Sum-reduction over BLOCK_K with fp32 + # accumulator — same precision path as the one_shot kernel. + rank_vec = tl.zeros((BLOCK_R,), dtype=tl.float32) + if EVEN_K: + for kb in range(0, K, BLOCK_K): + cur_k = kb + offs_k + x_tile = tl.load(x_row_ptr + cur_k * stride_xk).to(tl.float32) + a_tile = tl.load( + A_base + + safe_offs_r[:, None] * stride_A_r + + cur_k[None, :] * stride_A_k, + mask=rank_mask[:, None], + other=0.0, + ).to(tl.float32) + rank_vec += tl.sum(a_tile * x_tile[None, :], axis=1) + else: + for kb in range(0, K, BLOCK_K): + cur_k = kb + offs_k + k_mask = cur_k < K + x_tile = tl.load( + x_row_ptr + cur_k * stride_xk, mask=k_mask, other=0.0 + ).to(tl.float32) + a_tile = tl.load( + A_base + + safe_offs_r[:, None] * stride_A_r + + cur_k[None, :] * stride_A_k, + mask=rank_mask[:, None] & k_mask[None, :], + other=0.0, + ).to(tl.float32) + rank_vec += tl.sum(a_tile * x_tile[None, :], axis=1) + + # EXPAND: walk n_tiles_per_program consecutive output-N tiles + # using the same rank_vec. The loop is a runtime range (not + # tl.static_range) so a single compiled kernel handles every + # chunk size — see the note on the kernel signature. + n_tile_start = n_chunk_idx * n_tiles_per_program + out_row_ptr = out_ptr + slice_id * slice_n_offset + pair_idx * stride_om + + if MUL_ROUTED_WEIGHT: + moe_w = tl.load(topk_weights_ptr + pair_idx).to(tl.float32) + + for nt in range(n_tiles_per_program): + n_lo = (n_tile_start + nt) * BLOCK_N + if n_lo < N: + offs_n = n_lo + tl.arange(0, BLOCK_N) + n_mask = offs_n < N + b_tile = tl.load( + B_base + + offs_n[:, None] * stride_B_n + + safe_offs_r[None, :] * stride_B_r, + mask=n_mask[:, None] & rank_mask[None, :], + other=0.0, + ).to(tl.float32) + out_tile = tl.sum(b_tile * rank_vec[None, :], axis=1) + + if MUL_ROUTED_WEIGHT: + out_tile = out_tile * moe_w + + out_ptrs = out_row_ptr + offs_n * stride_on + if ADD_INPUTS: + prev = tl.load(out_ptrs, mask=n_mask, other=0.0).to(tl.float32) + tl.store( + out_ptrs, + (prev + out_tile).to(out_ptr.dtype.element_ty), + mask=n_mask, + ) + else: + tl.store( + out_ptrs, + out_tile.to(out_ptr.dtype.element_ty), + mask=n_mask, + ) + + +def _pick_small_batch_chunk(pair_slices: int, N_tiles: int, sm_count: int) -> int: + """Pick `n_tiles_per_program` so the launched grid stays near + 2*SM_count. + + Sizes for occupancy first (more programs in flight → better latency + hiding for the K-loop A/x loads). Once the per-tile grid already + exceeds 2*SM_count we increase the chunk size to amortise the shrink + cost — at that point the GPU is saturated by per-program work and + packing more tiles per program lets the rank_vec be reused. + """ + target_grid = max(1, 2 * sm_count) + total_work = pair_slices * N_tiles + if total_work <= target_grid: + return 1 + ntpp = (total_work + target_grid - 1) // target_grid + return min(ntpp, N_tiles) + + +def _run_fused_moe_lora_small_batch( + output: torch.Tensor, + qcurr_hidden_states: torch.Tensor, + lora_a_stacked: list[torch.Tensor], + lora_b_stacked: list[torch.Tensor], + topk_weights: torch.Tensor, + expert_ids_flat: torch.Tensor, # (num_tokens * top_k_num,) + token_lora_mapping: torch.Tensor, + top_k_num: int, + adapter_enabled: torch.Tensor, + mul_routed_weight: bool, + add_inputs: bool = True, +) -> None: + """Small-batch GEMV-style wrapper. Naive-block-assignment inputs only. + + Shape contract matches `_run_fused_moe_lora_one_shot`: `output` is + `(num_tokens, top_k_num, num_slices * N_per_slice)` with + contiguous-style strides, `expert_ids_flat` is the flattened + `topk_ids` of shape `(num_tokens * top_k_num,)`, and the + rank-padded LoRA weights live in `lora_a_stacked` / + `lora_b_stacked`. + + The kernel is persistent over (pair × slice × n_chunk) work items — + each program does one shrink and reuses the rank vector across + `n_tiles_per_program` expand tiles. The chunk size scales with the + pair-slice count so very small batches keep per-tile parallelism + while medium batches cut redundant shrinks. + """ + num_slices = len(lora_a_stacked) + device = qcurr_hidden_states.device + + A0 = lora_a_stacked[0] + B0 = lora_b_stacked[0] + max_loras_w = A0.shape[0] + rank = A0.shape[2] + K = A0.shape[3] + N_per_slice = B0.shape[2] + + # Rank padding: floor 16 (tensor-core min K), ceil to next pow2. The + # ≤64 cap is set conservatively for the prototype: at rank 64 the + # per-program register footprint is rank_vec(64 fp32) + b_tile(BLOCK_N + # × 64 fp32) = e.g. 128*64*4 = 32 KiB, comfortably within the 64 KiB + # register file even with num_warps=8. Doubling to 128 would push us + # against the limit and require shared-memory staging. + assert rank <= 64, f"fused_moe_lora_small_batch supports rank<=64; got rank={rank}" + BLOCK_R = max(triton.next_power_of_2(rank), 16) + + num_tokens = topk_weights.shape[0] + M_grid = num_tokens * top_k_num + if M_grid == 0 or num_slices == 0: + return + + token_mapping_factor = 1 if mul_routed_weight else top_k_num + + A_ptrs = _get_ptr(lora_a_stacked, device) + B_ptrs = _get_ptr(lora_b_stacked, device) + + assert output.dim() == 3, f"output must be 3-D, got {output.shape}" + assert output.stride(0) == output.shape[1] * output.stride(1), ( + "fused_moe_lora_small_batch requires output.stride(0) == " + f"top_k*stride(1); got shape={output.shape} strides={output.stride()}" + ) + out_view = output.view(-1, output.shape[-1]) + + # Block sizes. BLOCK_N=128 matches the one_shot's expand tile and gives + # 6-24 N tiles for typical N ∈ [768, 3072], enough to saturate the SM + # array once M_grid * num_slices reaches ~SM_count. BLOCK_K=128 halves + # the K-loop trip count vs 64 and pays for itself once K ≥ 1024 (the + # only regime we care about — hidden sizes are always large here). + BLOCK_N = 128 + BLOCK_K = 128 + nw = 4 + ns = 3 + + N_tiles = triton.cdiv(N_per_slice, BLOCK_N) + pair_slices = M_grid * num_slices + + sm_count = torch.cuda.get_device_properties(device).multi_processor_count + n_tiles_per_program = _pick_small_batch_chunk(pair_slices, N_tiles, sm_count) + n_chunks = triton.cdiv(N_tiles, n_tiles_per_program) + work_total = pair_slices * n_chunks + + # Grid sizing: keep parallelism uncapped when work_total is small (so + # very small batches still spread across SMs); cap at 2*SM_count once + # we have plenty of work, letting the in-kernel stride loop drain the + # remainder. + grid_size = min(work_total, max(1, 2 * sm_count)) + grid = (grid_size,) + + _fused_moe_lora_small_batch_kernel[grid]( + qcurr_hidden_states, + A_ptrs, + B_ptrs, + out_view, + topk_weights, + expert_ids_flat, + token_lora_mapping, + adapter_enabled, + N_per_slice, + K, + top_k_num, + max_loras_w, + work_total, + pair_slices, + qcurr_hidden_states.stride(0), + qcurr_hidden_states.stride(1), + A0.stride(0), + A0.stride(1), + A0.stride(2), + A0.stride(3), + B0.stride(0), + B0.stride(1), + B0.stride(2), + B0.stride(3), + out_view.stride(0), + out_view.stride(1), + N_per_slice, + n_tiles_per_program, + n_chunks, + token_mapping_factor=token_mapping_factor, + MUL_ROUTED_WEIGHT=mul_routed_weight, + ADD_INPUTS=add_inputs, + BLOCK_R=BLOCK_R, + actual_rank=rank, + BLOCK_N=BLOCK_N, + BLOCK_K=BLOCK_K, + NUM_SLICES=num_slices, + num_warps=nw, + num_stages=ns, + ) + + def _get_ptr(lora_weights: list[torch.Tensor], device: torch.device): """ `_LORA_PTR_DICT` collects the required information during `profile_run`, @@ -1111,7 +1458,41 @@ def _fused_moe_lora( # implementation. fully_sharded=True still needs the materialised # intermediate cache so that an all_reduce / all_gather can flow # between shrink and expand, so it falls through to the legacy path. - if not fully_sharded: + # VLLM_LORA_USE_ONE_SHOT_MOE=0 also forces the legacy path for + # debugging / benchmarking. + if not fully_sharded and envs.VLLM_LORA_USE_ONE_SHOT_MOE: + # Inside the one_shot fast path we further split between two + # kernels: + # * small-batch persistent GEMV — when the caller picked + # naive_block_assignment (sorted_token_ids is None — happens + # whenever num_tokens*top_k is sparse vs num_experts*max_loras, + # see SPARSITY_FACTOR in punica_gpu.add_lora_fused_moe), AND + # M_pairs * rank ≤ 1024 (cutoff from a GB200 sweep over ranks + # {16,32,64} — below this, the persistent GEMV path is + # 1.0-1.7x faster than the one_shot GEMM tile kernel). + # * one_shot GEMM tile kernel — everything else (prefill / large + # batch). Both are "fused" in that shrink+expand stay in + # registers; they differ only in tiling strategy. + M_pairs = topk_weights.numel() + if ( + sorted_token_ids is None + and max_lora_rank <= 64 + and M_pairs * max_lora_rank <= 1024 + ): + _run_fused_moe_lora_small_batch( + output, + qcurr_hidden_states, + lora_a_stacked, + lora_b_stacked, + topk_weights, + expert_ids, + token_lora_mapping, + top_k_num, + adapter_enabled, + mul_routed_weight, + add_inputs=add_inputs, + ) + return # shrink/expand BLOCK_SIZE_M must match the block_size that # moe_lora_align_block_size used; both shrink and expand pass the # same value (asserted by `shrink_block_size_m == expand_block_size_m` diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 055a3a15f380..d8a9142e6981 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -2083,7 +2083,7 @@ def apply( # # Note on torch.compile : # The whole MoE forward is already wrapped in torch.ops.vllm.moe_forward`, - # so we don't need to wrapp the following code as custom op + # so we don't need to wrap the following code as custom op sorted_token_ids_lora = None expert_ids_lora = None num_tokens_post_padded_lora = None From 9b1b13225184c9f8d55921252560d34596a2dd60 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Tue, 12 May 2026 17:29:54 +0000 Subject: [PATCH 06/10] Move Signed-off-by: Jee Jee Li --- vllm/lora/layers/base_linear.py | 9 +---- vllm/lora/layers/fused_moe.py | 34 ++----------------- vllm/lora/layers/utils.py | 17 ++++++++++ .../layers/fused_moe/experts/triton_moe.py | 4 +-- 4 files changed, 21 insertions(+), 43 deletions(-) diff --git a/vllm/lora/layers/base_linear.py b/vllm/lora/layers/base_linear.py index f4662bfa8e64..3168d25fe477 100644 --- a/vllm/lora/layers/base_linear.py +++ b/vllm/lora/layers/base_linear.py @@ -25,16 +25,9 @@ from vllm.utils.torch_utils import direct_register_custom_op from .base import BaseLayerWithLoRA -from .utils import _get_lora_device +from .utils import _get_lora_aux_cuda_stream, _get_lora_device if envs.VLLM_LORA_ENABLE_DUAL_STREAM: - _lora_aux_cuda_stream: torch.cuda.Stream | None = None - - def _get_lora_aux_cuda_stream() -> torch.cuda.Stream | None: - global _lora_aux_cuda_stream - if _lora_aux_cuda_stream is None and current_platform.is_cuda_alike(): - _lora_aux_cuda_stream = torch.cuda.Stream() - return _lora_aux_cuda_stream def lora_linear_async( layer_name: str, diff --git a/vllm/lora/layers/fused_moe.py b/vllm/lora/layers/fused_moe.py index 01843434473f..186bb896e024 100644 --- a/vllm/lora/layers/fused_moe.py +++ b/vllm/lora/layers/fused_moe.py @@ -20,20 +20,7 @@ ) from vllm.platforms import current_platform -from .utils import _get_lora_device - -# Process-wide singleton aux stream shared by every FusedMoEWithLoRA instance. -# Mirrors the pattern in vllm/lora/layers/base_linear.py: one extra stream is -# enough to overlap two compute streams; allocating one per layer would -# under-utilise the SMs and inflate context-switch cost. -_moe_lora_aux_cuda_stream: torch.cuda.Stream | None = None - - -def _get_moe_lora_aux_cuda_stream() -> torch.cuda.Stream | None: - global _moe_lora_aux_cuda_stream - if _moe_lora_aux_cuda_stream is None and current_platform.is_cuda_alike(): - _moe_lora_aux_cuda_stream = torch.cuda.Stream() - return _moe_lora_aux_cuda_stream +from .utils import _get_lora_aux_cuda_stream, _get_lora_device class FusedMoEWithLoRA(BaseLayerWithLoRA): @@ -101,7 +88,7 @@ def _init_lora_stream_context(self) -> None: return if not current_platform.is_cuda_alike(): return - self._lora_stream = _get_moe_lora_aux_cuda_stream() + self._lora_stream = _get_lora_aux_cuda_stream() # 4 events: 2 per (base GEMM, LoRA) pair so w13 and w2 don't reuse # the same event objects; reuse-within-a-pair is fine because the # second pair starts only after intermediate_cache1.add_() has joined. @@ -421,23 +408,6 @@ def runner(self): def is_internal_router(self) -> bool: return self.base_layer.is_internal_router - # TEMP(unblock-end-to-end): Upstream introduced FusedMoE.runner (see - # vllm/model_executor/layers/fused_moe/runner/), which holds gates and - # other linear submodules that LoRA wants to wrap. The LoRA module - # walker (vllm/lora/model_manager.py::_create_lora_modules) iterates - # the model BEFORE wrapping, then calls - # nn.Module.get_submodule("...experts.runner.X") — which uses - # hasattr/getattr — to attach LoRA wrappers. Once "experts" has been - # replaced by FusedMoEWithLoRA, that lookup fails because runner is - # only on base_layer. Forwarding via @property is enough to satisfy - # get_submodule without registering runner as our own child module - # (which would double-count parameters / state_dict entries). - # TODO(remove once upstream LoRA walker handles wrapped modules - # natively). - @property - def runner(self): - return self.base_layer.runner - @classmethod def can_replace_layer( cls, diff --git a/vllm/lora/layers/utils.py b/vllm/lora/layers/utils.py index 1b8083f5c4d1..40d46ac9d977 100644 --- a/vllm/lora/layers/utils.py +++ b/vllm/lora/layers/utils.py @@ -7,9 +7,26 @@ import torch import torch.nn as nn +from vllm import envs from vllm.model_executor.layers.fused_moe.fused_moe import try_get_optimal_moe_config +from vllm.platforms import current_platform from vllm.utils.math_utils import next_power_of_2 +_lora_aux_cuda_stream: torch.cuda.Stream | None = None + + +def _get_lora_aux_cuda_stream() -> torch.cuda.Stream | None: + # Gate stream creation on the dual-stream master switch so a stray call + # from a future code path cannot silently allocate a CUDA stream when the + # feature is turned off. MoE LoRA layers an additional VLLM_LORA_USE_ONE_SHOT_MOE + # gate at their call site. + if not envs.VLLM_LORA_ENABLE_DUAL_STREAM: + return None + global _lora_aux_cuda_stream + if _lora_aux_cuda_stream is None and current_platform.is_cuda_alike(): + _lora_aux_cuda_stream = torch.cuda.Stream() + return _lora_aux_cuda_stream + class LoRAMappingType(Enum): LANGUAGE = 1 diff --git a/vllm/model_executor/layers/fused_moe/experts/triton_moe.py b/vllm/model_executor/layers/fused_moe/experts/triton_moe.py index 4716317d4de7..a96d178a3fe1 100644 --- a/vllm/model_executor/layers/fused_moe/experts/triton_moe.py +++ b/vllm/model_executor/layers/fused_moe/experts/triton_moe.py @@ -240,9 +240,7 @@ def apply( # the LoRA writes its delta into a fresh zero buffer (add_inputs= # False) and we sum it into intermediate_cache1 after both finish. # - # Note on torch.compile : - # The whole MoE forward is already wrapped in torch.ops.vllm.moe_forward`, - # so we don't need to wrap the following code as custom op + sorted_token_ids_lora = None expert_ids_lora = None num_tokens_post_padded_lora = None From b1d60d975c98d05b1480731ff6d80bd42af4ce95 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Thu, 14 May 2026 17:25:37 +0000 Subject: [PATCH 07/10] Move Signed-off-by: Jee Jee Li --- vllm/lora/ops/triton_ops/fused_moe_lora_op.py | 32 ++++++++++++------- 1 file changed, 20 insertions(+), 12 deletions(-) diff --git a/vllm/lora/ops/triton_ops/fused_moe_lora_op.py b/vllm/lora/ops/triton_ops/fused_moe_lora_op.py index 58f04aeae947..1638e30b84fd 100644 --- a/vllm/lora/ops/triton_ops/fused_moe_lora_op.py +++ b/vllm/lora/ops/triton_ops/fused_moe_lora_op.py @@ -417,21 +417,29 @@ def _run_fused_moe_lora_one_shot( npid = max(1, min(npid, max(1, N_per_slice // 128))) # Robust defaults across the prefill regime (H100/H200/B200, bf16/fp16). - # NPID > 1 is the small-M / under-saturated path -- more warps + a - # deeper pipeline help amortise the inner-N expand loop. + # NPID > 1 is the small-M / under-saturated path -- more warps help + # amortise the inner-N expand loop. ns=3 instead of 4: GB200 ncu showed + # the 4-stage pipeline pushed register count to 168/thread and capped + # achieved occupancy at ~17% (3 blocks/SM, register-bound); ns=3 frees + # ~30 regs/thread which keeps a 4th block resident on small grids. if npid > 1: - block_n, nw, ns = 128, 8, 4 + block_n, nw, ns = 128, 8, 3 else: block_n, nw, ns = 128, 4, 3 - # BLOCK_K choice: when each expert sees enough tokens, the wider K tile - # halves the K-loop trip count and amortises load/MMA setup -- worth ~5 - # to 10% on GB200 for prefill-sized inputs. For sparsely-populated - # routings (e.g. M=16 mixtral, ~4 tokens/expert) the wider tile inflates - # per-program startup and most blocks early-exit anyway, so we keep the - # narrower tile. Threshold derived from GB200 sweeps over mixtral - # (E=8) and qwen3moe (E=64). - work_per_expert = topk_weights.numel() / max(num_experts, 1) - block_k = 128 if work_per_expert >= 16 else 64 + # BLOCK_K choice: for hidden-sized K (≥256, i.e. the K=hidden_size + # shrink input on w13) force BLOCK_K=128 -- the wider tile halves the + # K-loop trip count and removes the scoreboard stalls that dominated + # M=16-64 on GB200 (kernel time -13% to -37% vs the work_per_expert + # heuristic which picked 64 for low-tokens-per-expert ratios). For + # small-K shapes (e.g. w2 with K=192 where the down-proj reads the + # MoE intermediate) keep the work_per_expert heuristic: BLOCK_K=128 + # would force the EVEN_K=False masked path and add no K-loop savings + # (K/64=3 vs K/128=2 masked) while inflating per-program startup. + if K >= 256: + block_k = 128 + else: + work_per_expert = topk_weights.numel() / max(num_experts, 1) + block_k = 128 if work_per_expert >= 16 else 64 grid = (M_blocks * npid, num_slices, grid_lora_dim) From e4ec173e20448ded954bc43f79ca4b9808b93fd7 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Sat, 16 May 2026 17:19:29 +0000 Subject: [PATCH 08/10] Shrink tests Signed-off-by: Jee Jee Li --- tests/lora/test_fused_moe_lora_kernel.py | 568 ++---------------- vllm/lora/layers/base_linear.py | 15 +- vllm/lora/ops/triton_ops/fused_moe_lora_op.py | 8 +- 3 files changed, 52 insertions(+), 539 deletions(-) diff --git a/tests/lora/test_fused_moe_lora_kernel.py b/tests/lora/test_fused_moe_lora_kernel.py index 286a56e8308c..a70c5434736f 100644 --- a/tests/lora/test_fused_moe_lora_kernel.py +++ b/tests/lora/test_fused_moe_lora_kernel.py @@ -1034,164 +1034,79 @@ def _call_one_shot( ) +@pytest.mark.parametrize( + "trigger", + ["sorted_lora_ids_neg", "naive_mapping_neg", "naive_all_disabled"], +) @pytest.mark.parametrize("device", DEVICES) -def test_fused_moe_lora_kernel_no_active_loras(device): - """SORTED path: all entries in lora_ids are -1 -> every program must - early-exit at the lora_id<0 check and output must remain at its - pre-call residual value byte-for-byte.""" +def test_fused_moe_lora_kernel_one_shot_early_exit(trigger, device): + """one-shot must leave the residual byte-identical when every program + must early-exit. Three trigger conditions are covered: + + - "sorted_lora_ids_neg": sorted path, lora_ids all -1 (lora_id<0 check) + - "naive_mapping_neg": naive path, token_lora_mapping all -1 + - "naive_all_disabled": naive path, adapter_enabled all 0 + """ torch.set_default_device(device) set_random_seed(0) - num_tokens, top_k, E, max_loras, R, K, N = 32, 2, 8, 4, 16, 1024, 1024 + + # Per-trigger shapes: naive_mapping_neg needs the naive dispatch gate + # `num_tokens*top_k*8 <= num_experts*max_loras` to hold, hence the + # larger E/max_loras and smaller num_tokens. + if trigger == "naive_mapping_neg": + num_tokens, top_k, E, max_loras, R = 8, 2, 64, 8, 16 + elif trigger == "naive_all_disabled": + num_tokens, top_k, E, max_loras, R = 32, 2, 8, 4, 32 + else: # sorted_lora_ids_neg + num_tokens, top_k, E, max_loras, R = 32, 2, 8, 4, 16 + K, N = 1024, 1024 block_size, num_slices, dtype = 16, 2, torch.bfloat16 ( topk_ids, topk_weights, token_lora_mapping, - _, + lora_ids, lora_a_stacked, lora_b_stacked, hidden_states, ) = _build_one_shot_inputs( - num_tokens, - top_k, - E, - max_loras, - R, - K, - N, - num_slices, - block_size, - dtype, - ) - - lora_ids = torch.full((max_loras + 1,), -1, dtype=torch.int32) - adapter_enabled = torch.ones(max_loras + 1, dtype=torch.int32) - num_active_loras = torch.tensor([max_loras + 1], dtype=torch.int32, device="cpu") - residual = torch.randn((num_tokens, top_k, N), dtype=dtype) * 0.1 - output = residual.clone() - - # Sorted path with empty alignment: build minimal valid metadata. - max_pad = topk_ids.numel() + E * (block_size - 1) - max_pad = round_up(max_pad, block_size) - max_blocks = CEILDIV(max_pad, block_size) - sorted_token_ids = torch.zeros((max_loras, max_pad), dtype=torch.int32) - expert_ids = torch.full((max_loras, max_blocks), -1, dtype=torch.int32) - num_post = torch.zeros((max_loras,), dtype=torch.int32) - _call_one_shot( - output, - hidden_states, - lora_a_stacked, - lora_b_stacked, - topk_weights, - sorted_token_ids, - expert_ids, - num_post, - token_lora_mapping, - R, - top_k, - lora_ids, - num_active_loras, - adapter_enabled, - block_size, + num_tokens, top_k, E, max_loras, R, K, N, num_slices, block_size, dtype ) - torch.testing.assert_close(output, residual, atol=0, rtol=0) - - -@pytest.mark.parametrize("device", DEVICES) -def test_fused_moe_lora_kernel_naive_no_lora_tokens(device): - """NAIVE path: token_lora_mapping is all -1 -> lora_id resolves to -1 - in the kernel and every program early-exits. Residual preserved.""" - torch.set_default_device(device) - set_random_seed(0) - num_tokens, top_k, E, max_loras, R, K, N = 8, 2, 64, 8, 16, 1024, 1024 - num_slices, dtype = 2, torch.bfloat16 - (topk_ids, topk_weights, _, _, lora_a_stacked, lora_b_stacked, hidden_states) = ( - _build_one_shot_inputs( - num_tokens, - top_k, - E, - max_loras, - R, - K, - N, - num_slices, - 16, - dtype, - ) - ) - token_lora_mapping = torch.full((num_tokens,), -1, dtype=torch.int32) - lora_ids = torch.full((max_loras + 1,), -1, dtype=torch.int32) adapter_enabled = torch.ones(max_loras + 1, dtype=torch.int32) num_active_loras = torch.tensor([max_loras + 1], dtype=torch.int32, device="cpu") - residual = torch.randn((num_tokens, top_k, N), dtype=dtype) * 0.1 - output = residual.clone() - expert_ids = topk_ids.reshape(-1).contiguous() - _call_one_shot( - output, - hidden_states, - lora_a_stacked, - lora_b_stacked, - topk_weights, - None, - expert_ids, - None, - token_lora_mapping, - R, - top_k, - lora_ids, - num_active_loras, - adapter_enabled, - 16, - ) - torch.testing.assert_close(output, residual, atol=0, rtol=0) - -@pytest.mark.parametrize("device", DEVICES) -def test_fused_moe_lora_kernel_all_disabled(device): - """adapter_enabled is all-zero: every program must early-exit at the - enabled check; residual preserved.""" - torch.set_default_device(device) - set_random_seed(0) - num_tokens, top_k, E, max_loras, R, K, N = 32, 2, 8, 4, 32, 1024, 1024 - block_size, num_slices, dtype = 16, 2, torch.bfloat16 + if trigger == "sorted_lora_ids_neg": + lora_ids = torch.full((max_loras + 1,), -1, dtype=torch.int32) + max_pad = topk_ids.numel() + E * (block_size - 1) + max_pad = round_up(max_pad, block_size) + max_blocks = CEILDIV(max_pad, block_size) + sorted_token_ids = torch.zeros((max_loras, max_pad), dtype=torch.int32) + expert_ids = torch.full((max_loras, max_blocks), -1, dtype=torch.int32) + num_post = torch.zeros((max_loras,), dtype=torch.int32) + else: + sorted_token_ids = None + expert_ids = topk_ids.reshape(-1).contiguous() + num_post = None + if trigger == "naive_mapping_neg": + token_lora_mapping = torch.full((num_tokens,), -1, dtype=torch.int32) + lora_ids = torch.full((max_loras + 1,), -1, dtype=torch.int32) + else: # naive_all_disabled + adapter_enabled = torch.zeros(max_loras + 1, dtype=torch.int32) - ( - topk_ids, - topk_weights, - token_lora_mapping, - lora_ids, - lora_a_stacked, - lora_b_stacked, - hidden_states, - ) = _build_one_shot_inputs( - num_tokens, - top_k, - E, - max_loras, - R, - K, - N, - num_slices, - block_size, - dtype, - ) - adapter_enabled = torch.zeros(max_loras + 1, dtype=torch.int32) - num_active_loras = torch.tensor([max_loras + 1], dtype=torch.int32, device="cpu") residual = torch.randn((num_tokens, top_k, N), dtype=dtype) * 0.1 output = residual.clone() - expert_ids = topk_ids.reshape(-1).contiguous() _call_one_shot( output, hidden_states, lora_a_stacked, lora_b_stacked, topk_weights, - None, + sorted_token_ids, expert_ids, - None, + num_post, token_lora_mapping, R, top_k, @@ -1325,398 +1240,3 @@ def test_fused_moe_lora_kernel_rejects_bad_block_size_m(device): adapter_enabled, block_size, ) - - -@pytest.mark.parametrize("naive", [True, False]) -@pytest.mark.parametrize("num_slices", [1, 2]) -@pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("device", DEVICES) -def test_fused_moe_lora_kernel_add_inputs_parity(naive, num_slices, dtype, device): - """Prerequisite for the dual-stream LoRA path: add_inputs=False must - produce the LoRA delta only (matching what add_inputs=True would have - added on top of a residual). Concretely: - - run(add_inputs=True, output=residual) == residual - + run(add_inputs=False, - output=zeros) - - Covers both the SORTED path and the NAIVE block-assignment path.""" - torch.set_default_device(device) - set_random_seed(0) - - if naive: - # Naive path is gated by num_tokens * top_k * 8 <= num_experts * max_loras. - num_tokens, top_k, E, max_loras, R, K, N = 4, 2, 64, 8, 16, 1024, 512 - else: - num_tokens, top_k, E, max_loras, R, K, N = 32, 2, 8, 4, 16, 1024, 512 - block_size = 16 - - ( - topk_ids, - topk_weights, - token_lora_mapping, - lora_ids, - lora_a_stacked, - lora_b_stacked, - hidden_states, - ) = _build_one_shot_inputs( - num_tokens, top_k, E, max_loras, R, K, N, num_slices, block_size, dtype - ) - - adapter_enabled = torch.ones(max_loras + 1, dtype=torch.int32) - num_active_loras = torch.tensor([max_loras + 1], dtype=torch.int32, device="cpu") - - if naive: - sorted_token_ids = None - expert_ids = topk_ids.reshape(-1).contiguous() - num_tokens_post_padded = None - else: - max_pad = topk_ids.numel() + E * (block_size - 1) - max_pad = round_up(max_pad, block_size) - max_blocks = CEILDIV(max_pad, block_size) - sorted_token_ids = torch.empty((max_loras * max_pad,), dtype=torch.int32) - expert_ids = torch.empty((max_loras * max_blocks,), dtype=torch.int32) - num_tokens_post_padded = torch.empty((max_loras,), dtype=torch.int32) - ops.moe_lora_align_block_size( - topk_ids, - token_lora_mapping, - E, - block_size, - max_loras, - max_pad, - max_blocks, - sorted_token_ids, - expert_ids, - num_tokens_post_padded, - adapter_enabled, - lora_ids, - ) - sorted_token_ids = sorted_token_ids.view(max_loras, -1) - expert_ids = expert_ids.view(max_loras, -1) - - # Use a non-trivial residual so the in-place add semantics are exercised - # (a zero residual would make add_inputs=True and add_inputs=False - # trivially equal even if the kernel ignored the flag). - residual = torch.randn((num_tokens, top_k, num_slices * N), dtype=dtype) * 0.1 - - # Path A: run with add_inputs=True on top of the residual. - out_inplace = residual.clone() - _call_one_shot( - out_inplace, - hidden_states, - lora_a_stacked, - lora_b_stacked, - topk_weights, - sorted_token_ids, - expert_ids, - num_tokens_post_padded, - token_lora_mapping, - R, - top_k, - lora_ids, - num_active_loras, - adapter_enabled, - block_size, - add_inputs=True, - ) - - # Path B: run with add_inputs=False into a zero buffer to get the delta only. - out_delta = torch.zeros((num_tokens, top_k, num_slices * N), dtype=dtype) - _call_one_shot( - out_delta, - hidden_states, - lora_a_stacked, - lora_b_stacked, - topk_weights, - sorted_token_ids, - expert_ids, - num_tokens_post_padded, - token_lora_mapping, - R, - top_k, - lora_ids, - num_active_loras, - adapter_enabled, - block_size, - add_inputs=False, - ) - - # The kernel writes only into rows touched by some (lora_id, expert_id) - # tile; rows that no program owns retain their pre-call value. So an - # "untouched residual + 0" row in path A must equal "0 + untouched 0" - # in path B. The equality below holds for both touched and untouched - # rows: residual + delta_only == residual_in_inplace_run. - torch.testing.assert_close(out_inplace, residual + out_delta, atol=1e-2, rtol=1e-2) - - -@pytest.mark.parametrize("device", DEVICES) -def test_fused_moe_lora_kernel_dual_stream_smoke(device): - """End-to-end smoke for the dual-stream pattern used by - FusedMoEWithLoRA: run the LoRA fast path on an aux CUDA stream with - add_inputs=False, then sum into a residual on the default stream - via a cuda.Event-synchronised .add_(). The result must match the - in-place add_inputs=True path on the default stream. - - This complements test_fused_moe_lora_kernel_add_inputs_parity by - exercising the actual stream-coordination machinery (event.record / - event.wait) that TritonExperts.apply uses; the parity test only - proves the math. - """ - torch.set_default_device(device) - set_random_seed(0) - - num_tokens, top_k, E, max_loras, R, K, N = 32, 2, 8, 4, 16, 1024, 512 - block_size, num_slices, dtype = 16, 2, torch.bfloat16 - - ( - topk_ids, - topk_weights, - token_lora_mapping, - lora_ids, - lora_a_stacked, - lora_b_stacked, - hidden_states, - ) = _build_one_shot_inputs( - num_tokens, top_k, E, max_loras, R, K, N, num_slices, block_size, dtype - ) - adapter_enabled = torch.ones(max_loras + 1, dtype=torch.int32) - num_active_loras = torch.tensor([max_loras + 1], dtype=torch.int32, device="cpu") - - max_pad = topk_ids.numel() + E * (block_size - 1) - max_pad = round_up(max_pad, block_size) - max_blocks = CEILDIV(max_pad, block_size) - sorted_token_ids = torch.empty((max_loras * max_pad,), dtype=torch.int32) - expert_ids = torch.empty((max_loras * max_blocks,), dtype=torch.int32) - num_tokens_post_padded = torch.empty((max_loras,), dtype=torch.int32) - ops.moe_lora_align_block_size( - topk_ids, - token_lora_mapping, - E, - block_size, - max_loras, - max_pad, - max_blocks, - sorted_token_ids, - expert_ids, - num_tokens_post_padded, - adapter_enabled, - lora_ids, - ) - sorted_token_ids = sorted_token_ids.view(max_loras, -1) - expert_ids = expert_ids.view(max_loras, -1) - - residual = torch.randn((num_tokens, top_k, num_slices * N), dtype=dtype) * 0.1 - - # Reference: single-stream in-place add_inputs=True. - out_ref = residual.clone() - _call_one_shot( - out_ref, - hidden_states, - lora_a_stacked, - lora_b_stacked, - topk_weights, - sorted_token_ids, - expert_ids, - num_tokens_post_padded, - token_lora_mapping, - R, - top_k, - lora_ids, - num_active_loras, - adapter_enabled, - block_size, - add_inputs=True, - ) - - # Dual-stream pattern: LoRA delta on aux stream, then join + add_. - aux_stream = torch.cuda.Stream() - event0 = torch.cuda.Event() - event1 = torch.cuda.Event() - out_dual = residual.clone() - delta = torch.zeros_like(residual) - - event0.record() - with torch.cuda.stream(aux_stream): - event0.wait() - _call_one_shot( - delta, - hidden_states, - lora_a_stacked, - lora_b_stacked, - topk_weights, - sorted_token_ids, - expert_ids, - num_tokens_post_padded, - token_lora_mapping, - R, - top_k, - lora_ids, - num_active_loras, - adapter_enabled, - block_size, - add_inputs=False, - ) - event1.record() - event1.wait() - out_dual.add_(delta) - - torch.testing.assert_close(out_ref, out_dual, atol=1e-2, rtol=1e-2) - - -@pytest.mark.parametrize("device", DEVICES) -def test_fused_moe_lora_kernel_dual_stream_cuda_graph(device): - """The MoE-LoRA dual-stream block (LoRA delta on aux stream + .add_() - on default stream, joined via cuda.Event) must be capturable into a - torch.cuda.CUDAGraph and replayable with the same numerical result - as eager execution. This is the prerequisite for the dual-stream - path to work under vLLM's decode-time CUDA-graph capture. - """ - torch.set_default_device(device) - set_random_seed(0) - - num_tokens, top_k, E, max_loras, R, K, N = 32, 2, 8, 4, 16, 1024, 512 - block_size, num_slices, dtype = 16, 2, torch.bfloat16 - - ( - topk_ids, - topk_weights, - token_lora_mapping, - lora_ids, - lora_a_stacked, - lora_b_stacked, - hidden_states, - ) = _build_one_shot_inputs( - num_tokens, top_k, E, max_loras, R, K, N, num_slices, block_size, dtype - ) - adapter_enabled = torch.ones(max_loras + 1, dtype=torch.int32) - num_active_loras = torch.tensor([max_loras + 1], dtype=torch.int32, device="cpu") - - max_pad = topk_ids.numel() + E * (block_size - 1) - max_pad = round_up(max_pad, block_size) - max_blocks = CEILDIV(max_pad, block_size) - sorted_token_ids = torch.empty((max_loras * max_pad,), dtype=torch.int32) - expert_ids = torch.empty((max_loras * max_blocks,), dtype=torch.int32) - num_tokens_post_padded = torch.empty((max_loras,), dtype=torch.int32) - ops.moe_lora_align_block_size( - topk_ids, - token_lora_mapping, - E, - block_size, - max_loras, - max_pad, - max_blocks, - sorted_token_ids, - expert_ids, - num_tokens_post_padded, - adapter_enabled, - lora_ids, - ) - sorted_token_ids = sorted_token_ids.view(max_loras, -1) - expert_ids = expert_ids.view(max_loras, -1) - - # Persistent stream + events (must outlive the graph; created in capture - # would be replayed against fresh objects each replay, defeating the - # point). - aux_stream = torch.cuda.Stream() - event0 = torch.cuda.Event() - event1 = torch.cuda.Event() - - residual = torch.randn((num_tokens, top_k, num_slices * N), dtype=dtype) * 0.1 - - def _run_dual_stream(out_buf: torch.Tensor) -> None: - # Mirrors the structure of TritonExperts.apply's w13 / w2 blocks. - delta = torch.zeros_like(out_buf) - event0.record() - with torch.cuda.stream(aux_stream): - event0.wait() - _call_one_shot( - delta, - hidden_states, - lora_a_stacked, - lora_b_stacked, - topk_weights, - sorted_token_ids, - expert_ids, - num_tokens_post_padded, - token_lora_mapping, - R, - top_k, - lora_ids, - num_active_loras, - adapter_enabled, - block_size, - add_inputs=False, - ) - event1.record() - event1.wait() - out_buf.add_(delta) - - # Warm up: triton compile cache must be primed before capture, otherwise - # JIT compilation gets recorded into the graph (or fails capture). - warm = residual.clone() - _run_dual_stream(warm) - torch.cuda.synchronize() - - # Eager baseline - out_eager = residual.clone() - _run_dual_stream(out_eager) - torch.cuda.synchronize() - - # Captured + replay. Capture stream is separate from the default; events - # used inside must be recorded/waited on the capture stream (or aux - # stream that's dependent on it). torch.cuda.graph handles the stream - # accounting automatically. - g = torch.cuda.CUDAGraph() - out_graph = residual.clone() - capture_stream = torch.cuda.Stream() - capture_stream.wait_stream(torch.cuda.current_stream()) - with torch.cuda.stream(capture_stream), torch.cuda.graph(g, stream=capture_stream): - _run_dual_stream(out_graph) - torch.cuda.current_stream().wait_stream(capture_stream) - torch.cuda.synchronize() - - # First replay: the buffer was already populated by the capture itself, - # which is the eager-style write done during stream-recording. To - # validate replay semantics, reset the output and replay. - out_graph.copy_(residual) - g.replay() - torch.cuda.synchronize() - - torch.testing.assert_close(out_graph, out_eager, atol=1e-2, rtol=1e-2) - - # Replay a second time to confirm graph state is replayable repeatedly. - out_graph.copy_(residual) - g.replay() - torch.cuda.synchronize() - torch.testing.assert_close(out_graph, out_eager, atol=1e-2, rtol=1e-2) - - -def test_moe_forward_custom_op_registered(): - """The dual-stream MoE-LoRA path in TritonExperts.apply (see - vllm/model_executor/layers/fused_moe/fused_moe.py) relies on the entire - MoE forward being reachable only through `torch.ops.vllm.moe_forward` / - `torch.ops.vllm.moe_forward_shared`, both of which are opaque custom - ops. That opacity is what makes torch.compile / Dynamo stop *before* - seeing our `torch.cuda.stream(...)` / `event.record()/wait()` calls, - so we don't have to wrap the dual-stream block in its own custom op. - - If a future refactor drops these registrations (or reroutes the MoE - forward through a non-opaque path), the dual-stream code would start - triggering Dynamo graph breaks -- or, worse, fail silently under - torch.compile. This test exists to catch that regression at the - invariant level rather than via a flaky end-to-end compile run. - """ - # Importing the module side-effect-registers the ops. - import vllm.model_executor.layers.fused_moe.runner.moe_runner # noqa: F401 - - # Both ops must exist on torch.ops.vllm. - assert hasattr(torch.ops.vllm, "moe_forward"), ( - "torch.ops.vllm.moe_forward is gone. The MoE-LoRA dual-stream " - "path assumed this wrapper made the whole MoE forward opaque to " - "Dynamo. See the NOTE block above the registration in " - "vllm/model_executor/layers/fused_moe/runner/moe_runner.py." - ) - assert hasattr(torch.ops.vllm, "moe_forward_shared"), ( - "torch.ops.vllm.moe_forward_shared is gone. Same dual-stream " - "contract -- see runner/moe_runner.py NOTE." - ) diff --git a/vllm/lora/layers/base_linear.py b/vllm/lora/layers/base_linear.py index 3168d25fe477..cb65cf69504a 100644 --- a/vllm/lora/layers/base_linear.py +++ b/vllm/lora/layers/base_linear.py @@ -93,20 +93,7 @@ def _init_lora_stream_context(self) -> None: self.layer_name = self.base_layer.prefix + ".lora_linear_async" compilation_config = vllm_config.compilation_config if self.layer_name in compilation_config.static_forward_context: - # TEMP(unblock-end-to-end): Upstream FusedMoE.runner exposes the - # gate via an aliased path (mlp.gate AND mlp.experts.runner.gate - # refer to the same nn.Module), so LoRA module replacement wraps - # the same gate twice. Both wrappers compute self.layer_name - # from base_layer.prefix and thus collide here. The two - # wrappers share a base_layer, so disabling dual-stream on the - # duplicate (it falls through to _apply_sync) is safe and - # preserves correctness; only the first-registered wrapper - # keeps the overlap. TODO(remove once upstream LoRA walker - # de-duplicates by base-layer identity). - self._enable_aux_cuda_stream = False - self._lora_stream = None - self._events = [] - return + raise ValueError("Duplicate layer name: {}".format(self.layer_name)) compilation_config.static_forward_context[self.layer_name] = self def create_lora_weights( diff --git a/vllm/lora/ops/triton_ops/fused_moe_lora_op.py b/vllm/lora/ops/triton_ops/fused_moe_lora_op.py index 1638e30b84fd..9ec3a9d4d6d6 100644 --- a/vllm/lora/ops/triton_ops/fused_moe_lora_op.py +++ b/vllm/lora/ops/triton_ops/fused_moe_lora_op.py @@ -344,7 +344,10 @@ def _run_fused_moe_lora_one_shot( # rank padding is to next pow2 with a floor of 16 (tensor-core minimum # K-dim). Beyond 128 the (BLOCK_M, BLOCK_R) accumulator outgrows the # register file; rank tiling would be needed but is out of scope for - # this kernel. + # this kernel. Tried floor=32 to double MMA density per K-step but it + # regressed across all M (+8 to +40%): the (64,32) fp32 accumulator + + # widened B tile pushed register count past spill threshold, lowering + # occupancy by more than the MMA gain saved. assert rank <= 128, ( f"fused_moe_lora_one_shot supports max_lora_rank<=128; got rank={rank}" ) @@ -422,6 +425,9 @@ def _run_fused_moe_lora_one_shot( # the 4-stage pipeline pushed register count to 168/thread and capped # achieved occupancy at ~17% (3 blocks/SM, register-bound); ns=3 frees # ~30 regs/thread which keeps a 4th block resident on small grids. + # Tried BLOCK_N=64 for w13 (N=192) to avoid the half-wasted second + # tile: regressed 11-29% because the "waste" was just masked stores + # (cheap) and the extra iteration added load + index overhead. if npid > 1: block_n, nw, ns = 128, 8, 3 else: From 5bd9ea9b9ad86ed1eb21e06338f0d4a5d6cdd357 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Wed, 20 May 2026 01:07:18 +0000 Subject: [PATCH 09/10] FMT Signed-off-by: Jee Jee Li --- vllm/lora/ops/triton_ops/fused_moe_lora_op.py | 2 +- vllm/lora/punica_wrapper/punica_base.py | 2 ++ vllm/lora/punica_wrapper/punica_xpu.py | 2 ++ 3 files changed, 5 insertions(+), 1 deletion(-) diff --git a/vllm/lora/ops/triton_ops/fused_moe_lora_op.py b/vllm/lora/ops/triton_ops/fused_moe_lora_op.py index 9ec3a9d4d6d6..e2fb05a41bf4 100644 --- a/vllm/lora/ops/triton_ops/fused_moe_lora_op.py +++ b/vllm/lora/ops/triton_ops/fused_moe_lora_op.py @@ -355,7 +355,7 @@ def _run_fused_moe_lora_one_shot( num_experts = A0.shape[1] naive = sorted_token_ids is None - if naive: + if sorted_token_ids is None: EM_grid = topk_weights.numel() BLOCK_M = 16 stride_tl_ = 0 diff --git a/vllm/lora/punica_wrapper/punica_base.py b/vllm/lora/punica_wrapper/punica_base.py index 0448a6d00cda..086712991907 100644 --- a/vllm/lora/punica_wrapper/punica_base.py +++ b/vllm/lora/punica_wrapper/punica_base.py @@ -514,6 +514,7 @@ def add_lora_w13( num_slices: int, fully_sharded: bool, use_tuned_config: bool, + add_inputs: bool = True, token_lora_mapping: torch.Tensor | None = None, ) -> tuple[ torch.Tensor | None, @@ -554,6 +555,7 @@ def add_lora_w2( fully_sharded: bool, tp_rank: int, use_tuned_config: bool, + add_inputs: bool = True, ) -> None: """Apply w2 LoRA to y (intermediate_cache3) in-place before moe_sum. diff --git a/vllm/lora/punica_wrapper/punica_xpu.py b/vllm/lora/punica_wrapper/punica_xpu.py index 5b406b550bac..7fdadad09391 100755 --- a/vllm/lora/punica_wrapper/punica_xpu.py +++ b/vllm/lora/punica_wrapper/punica_xpu.py @@ -463,6 +463,7 @@ def add_lora_w13( num_slices: int, fully_sharded: bool, use_tuned_config: bool, + add_inputs: bool = True, token_lora_mapping: torch.Tensor | None = None, ) -> tuple[ torch.Tensor | None, @@ -596,6 +597,7 @@ def add_lora_w2( fully_sharded: bool, tp_rank: int, use_tuned_config: bool, + add_inputs: bool = True, ) -> None: import functools From d5754b0a995e301d6f2dab1fb0d683a70fc5460a Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Wed, 20 May 2026 15:14:01 +0000 Subject: [PATCH 10/10] Cleanup Signed-off-by: Jee Jee Li --- vllm/envs.py | 8 ------ vllm/lora/layers/fused_moe.py | 22 ++------------- vllm/lora/layers/utils.py | 4 --- vllm/lora/ops/triton_ops/fused_moe_lora_op.py | 28 ++----------------- 4 files changed, 4 insertions(+), 58 deletions(-) diff --git a/vllm/envs.py b/vllm/envs.py index df454ce07550..a787edf589d9 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -274,7 +274,6 @@ VLLM_XPU_ENABLE_XPU_GRAPH: bool = False VLLM_XPU_USE_SAMPLER_KERNEL: bool = True VLLM_LORA_ENABLE_DUAL_STREAM: bool = False - VLLM_LORA_USE_ONE_SHOT_MOE: bool = True def get_default_cache_root(): @@ -1822,13 +1821,6 @@ def _get_or_set_default() -> str: "VLLM_LORA_ENABLE_DUAL_STREAM": lambda: bool( int(os.getenv("VLLM_LORA_ENABLE_DUAL_STREAM", "0")) ), - # Whether to use the one-shot fused MoE LoRA kernel (combined shrink+expand). - # When disabled, falls back to the legacy two-kernel shrink/expand path. - # Dual-stream MoE LoRA depends on the one-shot kernel's add_inputs=False - # contract, so dual-stream is force-disabled when this is off. - "VLLM_LORA_USE_ONE_SHOT_MOE": lambda: bool( - int(os.getenv("VLLM_LORA_USE_ONE_SHOT_MOE", "1")) - ), # If set to 1, use Python spinloop extension to poll in a more efficient # way when using the mp backend. "VLLM_USE_SPINLOOP_EXT": lambda: bool(int(os.getenv("VLLM_USE_SPINLOOP_EXT", "0"))), diff --git a/vllm/lora/layers/fused_moe.py b/vllm/lora/layers/fused_moe.py index 4660e6c9d902..46ec26334159 100644 --- a/vllm/lora/layers/fused_moe.py +++ b/vllm/lora/layers/fused_moe.py @@ -35,16 +35,8 @@ def __init__(self, base_layer: FusedMoE) -> None: self.tp_size = self.base_layer.tp_size self.tp_rank = self.base_layer.tp_rank self.device = _get_lora_device(base_layer) - # Reuses VLLM_LORA_ENABLE_DUAL_STREAM (the same env that controls - # the linear-LoRA dual-stream path in - # vllm/lora/layers/base_linear.py); enabling it for a deployment - # turns dual-stream on for both linear and MoE LoRA layers in one - # switch. Dual-stream relies on the one-shot kernel's - # add_inputs=False contract, so it is force-disabled when the - # one-shot path is turned off via VLLM_LORA_USE_ONE_SHOT_MOE=0. - self._enable_aux_cuda_stream = ( - envs.VLLM_LORA_ENABLE_DUAL_STREAM and envs.VLLM_LORA_USE_ONE_SHOT_MOE - ) + + self._enable_aux_cuda_stream = envs.VLLM_LORA_ENABLE_DUAL_STREAM self._init_lora_stream_context() # For non-gated MoE (is_act_and_mul=False), only 1 slice is needed # since there's only up_proj (w1), not gate_proj + up_proj (w1 + w3) @@ -78,12 +70,6 @@ def __init__(self, base_layer: FusedMoE) -> None: ) def _init_lora_stream_context(self) -> None: - # Dual-stream is incompatible with fully_sharded MoE LoRA: that path - # routes through the legacy two-kernel flow with an embedded - # all_reduce / all_gather between shrink and expand, where the - # add_inputs=False contract is not wired (see _fused_moe_lora). - # When fully_sharded is enabled at LoRA-config time, we silently - # disable the dual-stream path here so the env var still works. self._lora_stream: torch.cuda.Stream | None = None self._events: tuple[torch.cuda.Event, ...] | None = None if not self._enable_aux_cuda_stream: @@ -97,10 +83,6 @@ def _init_lora_stream_context(self) -> None: self._events = tuple(torch.cuda.Event() for _ in range(4)) def _build_lora_context(self): - # Hand the stream/events to the experts only when fully_sharded is - # off (the path the one-shot kernel + add_inputs=False contract - # supports). For fully_sharded we leave aux_stream=None so - # experts.apply() takes the original sequential schedule. use_dual_stream = ( self._enable_aux_cuda_stream and not self.fully_sharded diff --git a/vllm/lora/layers/utils.py b/vllm/lora/layers/utils.py index 40d46ac9d977..cb2054fb5f0b 100644 --- a/vllm/lora/layers/utils.py +++ b/vllm/lora/layers/utils.py @@ -16,10 +16,6 @@ def _get_lora_aux_cuda_stream() -> torch.cuda.Stream | None: - # Gate stream creation on the dual-stream master switch so a stray call - # from a future code path cannot silently allocate a CUDA stream when the - # feature is turned off. MoE LoRA layers an additional VLLM_LORA_USE_ONE_SHOT_MOE - # gate at their call site. if not envs.VLLM_LORA_ENABLE_DUAL_STREAM: return None global _lora_aux_cuda_stream diff --git a/vllm/lora/ops/triton_ops/fused_moe_lora_op.py b/vllm/lora/ops/triton_ops/fused_moe_lora_op.py index e2fb05a41bf4..fbad39f43a16 100644 --- a/vllm/lora/ops/triton_ops/fused_moe_lora_op.py +++ b/vllm/lora/ops/triton_ops/fused_moe_lora_op.py @@ -3,7 +3,6 @@ import torch -from vllm import envs from vllm.distributed import ( tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce, @@ -1467,26 +1466,8 @@ def _fused_moe_lora( assert output.shape[0] == topk_weights.shape[0] assert top_k_num == topk_weights.shape[1] - # Fast path: single fused kernel keeps the rank-dim intermediate in - # registers and avoids the HBM round-trip of the legacy two-kernel - # implementation. fully_sharded=True still needs the materialised - # intermediate cache so that an all_reduce / all_gather can flow - # between shrink and expand, so it falls through to the legacy path. - # VLLM_LORA_USE_ONE_SHOT_MOE=0 also forces the legacy path for - # debugging / benchmarking. - if not fully_sharded and envs.VLLM_LORA_USE_ONE_SHOT_MOE: - # Inside the one_shot fast path we further split between two - # kernels: - # * small-batch persistent GEMV — when the caller picked - # naive_block_assignment (sorted_token_ids is None — happens - # whenever num_tokens*top_k is sparse vs num_experts*max_loras, - # see SPARSITY_FACTOR in punica_gpu.add_lora_fused_moe), AND - # M_pairs * rank ≤ 1024 (cutoff from a GB200 sweep over ranks - # {16,32,64} — below this, the persistent GEMV path is - # 1.0-1.7x faster than the one_shot GEMM tile kernel). - # * one_shot GEMM tile kernel — everything else (prefill / large - # batch). Both are "fused" in that shrink+expand stay in - # registers; they differ only in tiling strategy. + # Fast path: single fused kernel + if not fully_sharded: M_pairs = topk_weights.numel() if ( sorted_token_ids is None @@ -1532,11 +1513,6 @@ def _fused_moe_lora( ) return - # The legacy two-kernel path keeps the historical in-place semantics -- - # `_fused_moe_lora_expand` always sets `ADD_INPUTS=True` so the rank-dim - # cache flowing through all_reduce/all_gather stays consistent. The - # add_inputs=False contract is only wired for the one-shot fast path - # above, so reject it here rather than silently writing wrong results. assert add_inputs, ( "fused_moe_lora(add_inputs=False) is only supported on the " "fully_sharded=False fast path"