diff --git a/benchmarks/kernels/benchmark_fused_moe_lora_one_shot.py b/benchmarks/kernels/benchmark_fused_moe_lora_one_shot.py new file mode 100644 index 000000000000..c5a592ea7e2b --- /dev/null +++ b/benchmarks/kernels/benchmark_fused_moe_lora_one_shot.py @@ -0,0 +1,465 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Benchmark the fused MoE-LoRA fast path (one-shot) vs two-kernel baseline. + +The "one_shot" provider goes through `vllm.lora.ops.triton_ops.fused_moe_lora` +which dispatches to the single-kernel one-shot implementation when +fully_sharded=False (the prefill default). + +The "two_kernel" provider drives `fused_moe_lora_shrink` + `fused_moe_lora_expand` +directly, bypassing the dispatch and matching the legacy two-kernel path's +work distribution. This isolates the win from kernel fusion. + +Run: + .venv/bin/python -m benchmarks.kernels.benchmark_fused_moe_lora_one_shot + .venv/bin/python -m benchmarks.kernels.benchmark_fused_moe_lora_one_shot \\ + --model qwen3moe +""" + +from __future__ import annotations + +import argparse +import os +import random + +import torch + +from vllm import _custom_ops as ops +from vllm.lora.ops.triton_ops import ( + fused_moe_lora, + fused_moe_lora_expand, + fused_moe_lora_shrink, +) +from vllm.triton_utils import triton + +DTYPE = torch.bfloat16 +DEVICE = "cuda" + + +# ----- input fabrication ----------------------------------------------------- + + +def _round_up(x: int, base: int) -> int: + return ((x + base - 1) // base) * base + + +def _ceildiv(x: int, y: int) -> int: + return (x + y - 1) // y + + +def _assign_loras(num_tokens: int, num_sequences: int, max_loras: int) -> torch.Tensor: + tokens_per_seq = num_tokens // num_sequences + rem = num_tokens % num_sequences + out = torch.empty(num_tokens, dtype=torch.int32) + start = 0 + for i in range(num_sequences): + end = start + tokens_per_seq + (1 if i < rem else 0) + out[start:end] = random.randint(0, max_loras - 1) + start = end + return out + + +def _assign_experts(num_tokens: int, num_experts: int, top_k: int): + expert_indices = torch.empty((num_tokens, top_k), dtype=torch.int32) + for i in range(num_tokens): + expert_indices[i] = torch.randperm(num_experts)[:top_k] + weights = torch.rand((num_tokens, top_k), dtype=torch.float32) + weights = weights / weights.sum(dim=1, keepdim=True) + return expert_indices, weights + + +def _make_inputs( + M: int, + K: int, + N_per_slice: int, + rank: int, + num_experts: int, + top_k: int, + max_loras: int, + num_slices: int, + block_size_m: int, +): + """Mirrors the production caller's tensor layout.""" + torch.manual_seed(0) + random.seed(0) + + num_sequences = max(1, min(M, 8)) + topk_ids_cpu, topk_weights_cpu = _assign_experts(M, num_experts, top_k) + token_lora_cpu = _assign_loras(M, num_sequences, max_loras) + lora_ids_cpu = torch.full((max_loras + 1,), -1, dtype=torch.int32) + uniq = torch.unique(token_lora_cpu, sorted=True) + lora_ids_cpu[: uniq.size(0)].copy_(uniq) + + topk_ids = topk_ids_cpu.to(DEVICE) + topk_weights = topk_weights_cpu.to(device=DEVICE, dtype=DTYPE) + token_lora_mapping = token_lora_cpu.to(DEVICE) + lora_ids = lora_ids_cpu.to(DEVICE) + adapter_enabled = torch.ones(max_loras + 1, dtype=torch.int32, device=DEVICE) + + lora_a = [ + torch.randn((max_loras, num_experts, rank, K), dtype=DTYPE, device=DEVICE) + / max(K, 1) ** 0.5 + for _ in range(num_slices) + ] + lora_b = [ + torch.randn( + (max_loras, num_experts, N_per_slice, rank), + dtype=DTYPE, + device=DEVICE, + ) + / max(rank, 1) ** 0.5 + for _ in range(num_slices) + ] + hidden = torch.randn((M, K), dtype=DTYPE, device=DEVICE) + out_template = torch.zeros( + (M, top_k, num_slices * N_per_slice), dtype=DTYPE, device=DEVICE + ) + + # Sorted-path metadata (the prefill default). + max_pad = topk_ids.numel() + num_experts * (block_size_m - 1) + max_pad = _round_up(max_pad, block_size_m) + max_blocks = _ceildiv(max_pad, block_size_m) + sorted_token_ids = torch.empty( + (max_loras * max_pad,), dtype=torch.int32, device=DEVICE + ) + expert_ids = torch.empty( + (max_loras * max_blocks,), dtype=torch.int32, device=DEVICE + ) + num_post = torch.empty((max_loras,), dtype=torch.int32, device=DEVICE) + ops.moe_lora_align_block_size( + topk_ids, + token_lora_mapping, + num_experts, + block_size_m, + max_loras, + max_pad, + max_blocks, + sorted_token_ids, + expert_ids, + num_post, + adapter_enabled, + lora_ids, + ) + expert_ids = expert_ids.view(max_loras, -1).contiguous() + sorted_token_ids = sorted_token_ids.view(max_loras, -1).contiguous() + num_active = torch.tensor([max_loras + 1], dtype=torch.int32, device="cpu") + + return dict( + hidden=hidden, + lora_a=lora_a, + lora_b=lora_b, + topk_weights=topk_weights, + sorted_token_ids=sorted_token_ids, + expert_ids=expert_ids, + num_post=num_post, + token_lora_mapping=token_lora_mapping, + lora_ids=lora_ids, + num_active=num_active, + adapter_enabled=adapter_enabled, + out_template=out_template, + # bookkeeping + M=M, + K=K, + N_per_slice=N_per_slice, + rank=rank, + num_experts=num_experts, + top_k=top_k, + max_loras=max_loras, + num_slices=num_slices, + block_size_m=block_size_m, + ) + + +# ----- providers ------------------------------------------------------------- + + +def _run_one_shot(inp: dict): + """Drive `fused_moe_lora` with fully_sharded=False -> one-shot fast path.""" + out = inp["out_template"].clone() + fused_moe_lora( + out, + inp["hidden"], + inp["lora_a"], + inp["lora_b"], + inp["topk_weights"], + inp["sorted_token_ids"], + inp["expert_ids"], + inp["num_post"], + inp["token_lora_mapping"], + inp["rank"], + inp["top_k"], + inp["lora_ids"], + inp["num_active"], + inp["adapter_enabled"], + inp["block_size_m"], + 64, + 32, + 8, + 4, + 3, + 1, + inp["block_size_m"], + 64, + 32, + 8, + 4, + 3, + 1, + False, + False, + 0, + ) + return out + + +def _run_two_kernel(inp: dict): + """Drive `fused_moe_lora_shrink` + `fused_moe_lora_expand` directly, + bypassing the dispatch. Matches the legacy two-kernel work distribution. + """ + M = inp["M"] + top_k = inp["top_k"] + rank = inp["rank"] + num_slices = inp["num_slices"] + N_per_slice = inp["N_per_slice"] + K = inp["K"] + num_experts = inp["num_experts"] + block_m = inp["block_size_m"] + + intermediate = torch.zeros((num_slices, M, top_k, rank), dtype=DTYPE, device=DEVICE) + out = inp["out_template"].clone() + EM = inp["sorted_token_ids"].shape[1] + num_tokens = M * top_k + + fused_moe_lora_shrink( + intermediate, + inp["hidden"], + inp["lora_a"], + inp["topk_weights"], + inp["sorted_token_ids"], + inp["expert_ids"], + inp["num_post"], + inp["token_lora_mapping"], + top_k, + inp["lora_ids"], + inp["adapter_enabled"], + torch.device(DEVICE), + rank, + M, + EM, + K, + num_tokens, + num_experts, + num_slices, + block_m, + 64, + 32, + 8, + 4, + 3, + 1, + inp["num_active"], + False, + ) + fused_moe_lora_expand( + out, + intermediate, + inp["lora_b"], + inp["topk_weights"], + inp["sorted_token_ids"], + inp["expert_ids"], + inp["num_post"], + inp["token_lora_mapping"], + top_k, + inp["lora_ids"], + inp["adapter_enabled"], + torch.device(DEVICE), + rank, + M, + EM, + K, + num_tokens, + num_experts, + num_slices, + rank, + N_per_slice, + block_m, + 64, + 32, + 8, + 4, + 3, + 1, + inp["num_active"], + False, + 0, + ) + return out + + +PROVIDER_FNS = { + "one_shot": _run_one_shot, + "two_kernel": _run_two_kernel, +} + + +# ----- model presets --------------------------------------------------------- + + +MODEL_PRESETS: dict[str, dict] = { + # Mixtral-8x7B style: E=8, top_k=2, hidden=4096, intermediate=14336 + "mixtral": dict( + K=4096, + N_per_slice=7168, + num_experts=8, + top_k=2, + max_loras=4, + num_slices=2, + block_size_m=64, + ), + # Qwen3-MoE / DeepSeek-V2 style: E=64, top_k=8, hidden=2048, inter=1408 + "qwen3moe": dict( + K=2048, + N_per_slice=1408, + num_experts=64, + top_k=8, + max_loras=4, + num_slices=2, + block_size_m=64, + ), + # GLM-5.1 (zai-org/GLM-5.1-FP8): E=256, top_k=8, hidden=6144, + # moe_intermediate=2048 + "glm5_1": dict( + K=6144, + N_per_slice=2048, + num_experts=256, + top_k=8, + max_loras=4, + num_slices=2, + block_size_m=64, + ), +} + + +M_RANGE = [16, 64, 256, 1024, 4096, 16384] +RANK_RANGE = [8, 16, 32, 64] + + +def get_benchmark(model: str, max_loras: int | None = None): + preset = dict(MODEL_PRESETS[model]) + if max_loras is not None: + preset["max_loras"] = max_loras + + @triton.testing.perf_report( + triton.testing.Benchmark( + x_names=["M", "rank"], + x_vals=[(M, R) for M in M_RANGE for R in RANK_RANGE], + line_arg="provider", + line_vals=list(PROVIDER_FNS.keys()), + line_names=["one_shot (fused)", "two_kernel (legacy)"], + styles=[("red", "-"), ("blue", "-")], + ylabel="ms", + plot_name=f"fused_moe_lora-{model}-loras{preset['max_loras']}", + args={"preset": preset}, + ) + ) + def benchmark(M, rank, provider, preset): + inp = _make_inputs( + M=M, + K=preset["K"], + N_per_slice=preset["N_per_slice"], + rank=rank, + num_experts=preset["num_experts"], + top_k=preset["top_k"], + max_loras=preset["max_loras"], + num_slices=preset["num_slices"], + block_size_m=preset["block_size_m"], + ) + fn = PROVIDER_FNS[provider] + quantiles = [0.5, 0.2, 0.8] + ms, min_ms, max_ms = triton.testing.do_bench( + lambda: fn(inp), quantiles=quantiles + ) + return ms, max_ms, min_ms + + return benchmark + + +# ----- correctness sanity --------------------------------------------------- + + +def calculate_diff(model: str, M: int, rank: int, max_loras: int | None = None): + preset = dict(MODEL_PRESETS[model]) + if max_loras is not None: + preset["max_loras"] = max_loras + inp = _make_inputs( + M=M, + K=preset["K"], + N_per_slice=preset["N_per_slice"], + rank=rank, + num_experts=preset["num_experts"], + top_k=preset["top_k"], + max_loras=preset["max_loras"], + num_slices=preset["num_slices"], + block_size_m=preset["block_size_m"], + ) + out_one = _run_one_shot(inp) + out_two = _run_two_kernel(inp) + max_abs = (out_one.float() - out_two.float()).abs().max().item() + print( + f" model={model:<9} M={M:<6} rank={rank:<3} " + f"max|one_shot - two_kernel|={max_abs:.4g} " + f"ref|max|={out_two.float().abs().max().item():.3g}" + ) + if max_abs <= 5e-2: + print(" ✅ outputs match within bf16 tolerance") + else: + print(" ❌ outputs differ beyond expected bf16 noise") + + +# ----- main ------------------------------------------------------------------ + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--model", + type=str, + default="mixtral", + choices=list(MODEL_PRESETS.keys()), + help="Model preset to sweep", + ) + parser.add_argument( + "--save-path", + type=str, + default="./configs/fused_moe_lora_one_shot/", + help="Directory to save benchmark results", + ) + parser.add_argument( + "--check-only", + action="store_true", + help="Run correctness sanity check only, no perf sweep", + ) + parser.add_argument( + "--max-loras", + type=int, + default=None, + help="Override max_loras in the model preset (number of LoRA adapters " + "active in the batch). Defaults to the preset's value.", + ) + args = parser.parse_args() + + print(f"Correctness check ({args.model}):") + calculate_diff(args.model, M=256, rank=32, max_loras=args.max_loras) + if args.check_only: + raise SystemExit(0) + + effective_max_loras = ( + args.max_loras + if args.max_loras is not None + else MODEL_PRESETS[args.model]["max_loras"] + ) + print(f"\nGPU: {torch.cuda.get_device_name()}") + print(f"Model preset: {args.model} max_loras={effective_max_loras}\n") + benchmark = get_benchmark(args.model, max_loras=args.max_loras) + os.makedirs(args.save_path, exist_ok=True) + benchmark.run(print_data=True, save_path=args.save_path) diff --git a/tests/lora/test_fused_moe_lora_kernel.py b/tests/lora/test_fused_moe_lora_kernel.py index e50d7d5aacfe..a70c5434736f 100644 --- a/tests/lora/test_fused_moe_lora_kernel.py +++ b/tests/lora/test_fused_moe_lora_kernel.py @@ -141,6 +141,7 @@ def use_fused_moe_lora_kernel( block_size, fully_sharded=False, offset=0, + add_inputs=True, ): max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1) max_num_tokens_padded = round_up(max_num_tokens_padded, block_size) @@ -222,6 +223,7 @@ def use_fused_moe_lora_kernel( mul_routed_weight, fully_sharded=fully_sharded, offset=offset, + add_inputs=add_inputs, ) @@ -371,6 +373,7 @@ def use_fused_moe_lora_kernel_naive( block_size, fully_sharded=False, offset=0, + add_inputs=True, ): """ Test helper for naive_block_assignment path. @@ -435,6 +438,7 @@ def use_fused_moe_lora_kernel_naive( mul_routed_weight=mul_routed_weight, fully_sharded=fully_sharded, offset=offset, + add_inputs=add_inputs, ) @@ -745,3 +749,494 @@ def _get_shard_slice(shard_size): output = tensor_model_parallel_all_reduce(output) torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2) + + +# -- one-shot fast-path coverage -------------------------------------------- +# The fused shrink+expand one-shot kernel pads `BLOCK_R` to next_pow2(rank), +# with a floor of 16 (tensor-core minimum). Small ranks (4, 8) exercise the +# rank-dim masking and are not covered by the original tests, which start at +# rank=16. The legacy two-kernel path additionally fails on rank=4 in TMA +# mode because the rank-dim stride (rank * elem_size) is not 16-byte +# aligned; the one-shot fast path takes precedence whenever fully_sharded +# is False so this regression is hidden in normal use, but the test still +# ensures the one-shot logic is correct against the pytorch reference. + + +@pytest.mark.parametrize("num_tokens", [16, 100]) +@pytest.mark.parametrize("top_k_num", [2]) +@pytest.mark.parametrize("num_experts", [8, 64]) +@pytest.mark.parametrize("max_loras", [4]) +@pytest.mark.parametrize("N", [1408]) +@pytest.mark.parametrize("K", [2048]) +@pytest.mark.parametrize("max_lora_rank", [4, 8]) +@pytest.mark.parametrize("block_size", [16, 64]) +@pytest.mark.parametrize("num_slices", [1, 2]) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("device", DEVICES) +@pytest.mark.parametrize("seed", SEED) +def test_fused_moe_lora_kernel_small_rank( + num_tokens, + top_k_num, + num_experts, + max_loras, + N, + K, + max_lora_rank, + block_size, + num_slices, + dtype, + device, + seed, +): + """One-shot fast path covering rank<16 (padded to BLOCK_R=16 inside kernel).""" + torch.set_default_device(device) + set_random_seed(seed) + num_sequences = max(1, min(num_tokens, 8)) + topk_ids, topk_weights, token_lora_mapping, lora_ids = sample_data( + num_tokens, num_sequences, max_loras, num_experts, top_k_num + ) + + lora_a_stacked = [ + torch.rand( + (max_loras, num_experts, max_lora_rank, K), + dtype=dtype, + ) + for _ in range(num_slices) + ] + lora_b_stacked = [ + torch.rand( + (max_loras, num_experts, N // num_slices, max_lora_rank), + dtype=dtype, + ) + for _ in range(num_slices) + ] + hidden_states = torch.rand((num_tokens, K), dtype=dtype) + + output = torch.zeros((num_tokens, top_k_num, N), dtype=dtype) + use_fused_moe_lora_kernel( + topk_ids, + topk_weights, + token_lora_mapping, + max_lora_rank, + top_k_num, + lora_ids, + lora_a_stacked, + lora_b_stacked, + hidden_states, + output, + max_loras, + num_experts, + block_size, + ) + output_ref = use_torch( + hidden_states, + token_lora_mapping, + topk_ids, + lora_a_stacked, + lora_b_stacked, + top_k_num, + num_slices, + ) + torch.testing.assert_close(output, output_ref, atol=1e-2, rtol=1e-2) + + +@pytest.mark.parametrize("num_tokens", [16, 64]) +@pytest.mark.parametrize("top_k_num", [2]) +@pytest.mark.parametrize("num_experts", [8]) +@pytest.mark.parametrize("max_loras", [4]) +@pytest.mark.parametrize("N", [2048]) +@pytest.mark.parametrize("K", [4096]) +@pytest.mark.parametrize("max_lora_rank", [8, 16, 32, 64]) +@pytest.mark.parametrize("block_size", [64]) +@pytest.mark.parametrize("num_slices", [2]) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("device", DEVICES) +@pytest.mark.parametrize("seed", SEED) +def test_fused_moe_lora_kernel_npid_path( + num_tokens, + top_k_num, + num_experts, + max_loras, + N, + K, + max_lora_rank, + block_size, + num_slices, + dtype, + device, + seed, +): + """Exercise the small-batch / NPID > 1 branch of the one-shot fast path. + + With these sizes the one-shot wrapper computes NPID_FACTOR > 1 (base CTA count + < SM count), so each program covers only an outer chunk of N. The + cross-outer-block write mask is the correctness-critical bit. + """ + torch.set_default_device(device) + set_random_seed(seed) + num_sequences = max(1, min(num_tokens, 4)) + topk_ids, topk_weights, token_lora_mapping, lora_ids = sample_data( + num_tokens, num_sequences, max_loras, num_experts, top_k_num + ) + + lora_a_stacked = [ + torch.rand( + (max_loras, num_experts, max_lora_rank, K), + dtype=dtype, + ) + for _ in range(num_slices) + ] + lora_b_stacked = [ + torch.rand( + (max_loras, num_experts, N // num_slices, max_lora_rank), + dtype=dtype, + ) + for _ in range(num_slices) + ] + hidden_states = torch.rand((num_tokens, K), dtype=dtype) + + output = torch.zeros((num_tokens, top_k_num, N), dtype=dtype) + use_fused_moe_lora_kernel( + topk_ids, + topk_weights, + token_lora_mapping, + max_lora_rank, + top_k_num, + lora_ids, + lora_a_stacked, + lora_b_stacked, + hidden_states, + output, + max_loras, + num_experts, + block_size, + ) + output_ref = use_torch( + hidden_states, + token_lora_mapping, + topk_ids, + lora_a_stacked, + lora_b_stacked, + top_k_num, + num_slices, + ) + torch.testing.assert_close(output, output_ref, atol=2e-2, rtol=2e-2) + + +# -- one-shot corner-case coverage ------------------------------------------ +# Each of the following exercises a path where the kernel is launched but +# every program early-exits, leaving the output unchanged. The contract is +# additive (`output += contribution`), so an empty contribution must leave +# the input residual untouched. + + +def _build_one_shot_inputs( + num_tokens, + top_k_num, + num_experts, + max_loras, + max_lora_rank, + K, + N, + num_slices, + block_size, + dtype, +): + """Common scaffolding for the corner-case tests below.""" + num_sequences = max(1, min(num_tokens, 4)) if num_tokens > 0 else 1 + if num_tokens > 0: + topk_ids, topk_weights, token_lora_mapping, lora_ids = sample_data( + num_tokens, num_sequences, max_loras, num_experts, top_k_num + ) + else: + # M=0 path: caller may still hand us empty tensors with the right shape. + topk_ids = torch.empty((0, top_k_num), dtype=torch.int32) + topk_weights = torch.empty((0, top_k_num), dtype=torch.float32) + token_lora_mapping = torch.empty((0,), dtype=torch.int32) + lora_ids = torch.full((max_loras + 1,), -1, dtype=torch.int32) + + lora_a_stacked = [ + torch.rand((max_loras, num_experts, max_lora_rank, K), dtype=dtype) + for _ in range(num_slices) + ] + lora_b_stacked = [ + torch.rand( + (max_loras, num_experts, N // num_slices, max_lora_rank), dtype=dtype + ) + for _ in range(num_slices) + ] + hidden_states = torch.rand((max(num_tokens, 0), K), dtype=dtype) + return ( + topk_ids, + topk_weights.to(dtype), + token_lora_mapping, + lora_ids, + lora_a_stacked, + lora_b_stacked, + hidden_states, + ) + + +def _call_one_shot( + output, + hidden_states, + lora_a_stacked, + lora_b_stacked, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + token_lora_mapping, + max_lora_rank, + top_k_num, + lora_ids, + num_active_loras, + adapter_enabled, + block_size, + add_inputs=True, +): + """Direct call into fused_moe_lora with one-shot-routed defaults.""" + from vllm.lora.ops.triton_ops import fused_moe_lora as _op + + _op( + output, + hidden_states, + lora_a_stacked, + lora_b_stacked, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + token_lora_mapping, + max_lora_rank, + top_k_num, + lora_ids, + num_active_loras, + adapter_enabled, + block_size, + 32, + 64, + 1, + 4, + 3, + 1, + block_size, + 32, + 64, + 1, + 4, + 3, + 1, + False, + False, + 0, + add_inputs, + ) + + +@pytest.mark.parametrize( + "trigger", + ["sorted_lora_ids_neg", "naive_mapping_neg", "naive_all_disabled"], +) +@pytest.mark.parametrize("device", DEVICES) +def test_fused_moe_lora_kernel_one_shot_early_exit(trigger, device): + """one-shot must leave the residual byte-identical when every program + must early-exit. Three trigger conditions are covered: + + - "sorted_lora_ids_neg": sorted path, lora_ids all -1 (lora_id<0 check) + - "naive_mapping_neg": naive path, token_lora_mapping all -1 + - "naive_all_disabled": naive path, adapter_enabled all 0 + """ + torch.set_default_device(device) + set_random_seed(0) + + # Per-trigger shapes: naive_mapping_neg needs the naive dispatch gate + # `num_tokens*top_k*8 <= num_experts*max_loras` to hold, hence the + # larger E/max_loras and smaller num_tokens. + if trigger == "naive_mapping_neg": + num_tokens, top_k, E, max_loras, R = 8, 2, 64, 8, 16 + elif trigger == "naive_all_disabled": + num_tokens, top_k, E, max_loras, R = 32, 2, 8, 4, 32 + else: # sorted_lora_ids_neg + num_tokens, top_k, E, max_loras, R = 32, 2, 8, 4, 16 + K, N = 1024, 1024 + block_size, num_slices, dtype = 16, 2, torch.bfloat16 + + ( + topk_ids, + topk_weights, + token_lora_mapping, + lora_ids, + lora_a_stacked, + lora_b_stacked, + hidden_states, + ) = _build_one_shot_inputs( + num_tokens, top_k, E, max_loras, R, K, N, num_slices, block_size, dtype + ) + + adapter_enabled = torch.ones(max_loras + 1, dtype=torch.int32) + num_active_loras = torch.tensor([max_loras + 1], dtype=torch.int32, device="cpu") + + if trigger == "sorted_lora_ids_neg": + lora_ids = torch.full((max_loras + 1,), -1, dtype=torch.int32) + max_pad = topk_ids.numel() + E * (block_size - 1) + max_pad = round_up(max_pad, block_size) + max_blocks = CEILDIV(max_pad, block_size) + sorted_token_ids = torch.zeros((max_loras, max_pad), dtype=torch.int32) + expert_ids = torch.full((max_loras, max_blocks), -1, dtype=torch.int32) + num_post = torch.zeros((max_loras,), dtype=torch.int32) + else: + sorted_token_ids = None + expert_ids = topk_ids.reshape(-1).contiguous() + num_post = None + if trigger == "naive_mapping_neg": + token_lora_mapping = torch.full((num_tokens,), -1, dtype=torch.int32) + lora_ids = torch.full((max_loras + 1,), -1, dtype=torch.int32) + else: # naive_all_disabled + adapter_enabled = torch.zeros(max_loras + 1, dtype=torch.int32) + + residual = torch.randn((num_tokens, top_k, N), dtype=dtype) * 0.1 + output = residual.clone() + + _call_one_shot( + output, + hidden_states, + lora_a_stacked, + lora_b_stacked, + topk_weights, + sorted_token_ids, + expert_ids, + num_post, + token_lora_mapping, + R, + top_k, + lora_ids, + num_active_loras, + adapter_enabled, + block_size, + ) + torch.testing.assert_close(output, residual, atol=0, rtol=0) + + +@pytest.mark.parametrize("device", DEVICES) +def test_fused_moe_lora_kernel_zero_grid_no_crash(device): + """num_active_loras=0 (or num_slices=0) would otherwise launch a grid + with a zero dimension. one-shot wrapper must short-circuit before launch.""" + torch.set_default_device(device) + set_random_seed(0) + num_tokens, top_k, E, max_loras, R, K, N = 8, 2, 8, 4, 16, 1024, 1024 + block_size, num_slices, dtype = 16, 2, torch.bfloat16 + + ( + topk_ids, + topk_weights, + token_lora_mapping, + lora_ids, + lora_a_stacked, + lora_b_stacked, + hidden_states, + ) = _build_one_shot_inputs( + num_tokens, + top_k, + E, + max_loras, + R, + K, + N, + num_slices, + block_size, + dtype, + ) + adapter_enabled = torch.ones(max_loras + 1, dtype=torch.int32) + num_active_loras = torch.tensor([0], dtype=torch.int32, device="cpu") + residual = torch.randn((num_tokens, top_k, N), dtype=dtype) * 0.1 + output = residual.clone() + + # sorted path is the one that uses num_active_loras for grid axis 2 + max_pad = topk_ids.numel() + E * (block_size - 1) + max_pad = round_up(max_pad, block_size) + max_blocks = CEILDIV(max_pad, block_size) + sorted_token_ids = torch.zeros((max_loras, max_pad), dtype=torch.int32) + expert_ids = torch.full((max_loras, max_blocks), -1, dtype=torch.int32) + num_post = torch.zeros((max_loras,), dtype=torch.int32) + _call_one_shot( + output, + hidden_states, + lora_a_stacked, + lora_b_stacked, + topk_weights, + sorted_token_ids, + expert_ids, + num_post, + token_lora_mapping, + R, + top_k, + lora_ids, + num_active_loras, + adapter_enabled, + block_size, + ) + torch.testing.assert_close(output, residual, atol=0, rtol=0) + + +@pytest.mark.parametrize("device", DEVICES) +def test_fused_moe_lora_kernel_rejects_bad_block_size_m(device): + """one-shot must surface a clear assertion when shrink_block_size_m is not + a power of 2 / less than 16, instead of the cryptic Triton compile + failure (`arange's range must be a power of 2`).""" + torch.set_default_device(device) + set_random_seed(0) + num_tokens, top_k, E, max_loras, R, K, N = 32, 2, 8, 4, 16, 1024, 1024 + num_slices, dtype = 2, torch.bfloat16 + block_size = 24 # NOT a power of 2 + + ( + topk_ids, + topk_weights, + token_lora_mapping, + lora_ids, + lora_a_stacked, + lora_b_stacked, + hidden_states, + ) = _build_one_shot_inputs( + num_tokens, + top_k, + E, + max_loras, + R, + K, + N, + num_slices, + 16, + dtype, + ) + # Build sorted-mode metadata at block_size=16 so shapes are sane, + # but pass block_size=24 to the op (the buggy combination). + max_pad = topk_ids.numel() + E * (16 - 1) + max_pad = round_up(max_pad, 16) + max_blocks = CEILDIV(max_pad, 16) + sorted_token_ids = torch.zeros((max_loras, max_pad), dtype=torch.int32) + expert_ids = torch.full((max_loras, max_blocks), -1, dtype=torch.int32) + num_post = torch.zeros((max_loras,), dtype=torch.int32) + adapter_enabled = torch.ones(max_loras + 1, dtype=torch.int32) + num_active_loras = torch.tensor([max_loras + 1], dtype=torch.int32, device="cpu") + output = torch.zeros((num_tokens, top_k, N), dtype=dtype) + + with pytest.raises(AssertionError, match="shrink_block_size_m"): + _call_one_shot( + output, + hidden_states, + lora_a_stacked, + lora_b_stacked, + topk_weights, + sorted_token_ids, + expert_ids, + num_post, + token_lora_mapping, + R, + top_k, + lora_ids, + num_active_loras, + adapter_enabled, + block_size, + ) diff --git a/vllm/envs.py b/vllm/envs.py index f5b2759e9934..b2c5f22567fa 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -1971,6 +1971,8 @@ def _resolve_rust_frontend_path() -> str | None: int(os.getenv("VLLM_USE_SIMPLE_KV_OFFLOAD", "0")) ), # Whether to enable dual cuda streams for LoRA computation + # (used by both BaseLinearLayerWithLoRA and FusedMoEWithLoRA to + # overlap the base layer compute with the LoRA fast path). "VLLM_LORA_ENABLE_DUAL_STREAM": lambda: bool( int(os.getenv("VLLM_LORA_ENABLE_DUAL_STREAM", "0")) ), diff --git a/vllm/lora/layers/base_linear.py b/vllm/lora/layers/base_linear.py index 68783ae50d4b..cb65cf69504a 100644 --- a/vllm/lora/layers/base_linear.py +++ b/vllm/lora/layers/base_linear.py @@ -25,16 +25,9 @@ from vllm.utils.torch_utils import direct_register_custom_op from .base import BaseLayerWithLoRA -from .utils import _get_lora_device +from .utils import _get_lora_aux_cuda_stream, _get_lora_device if envs.VLLM_LORA_ENABLE_DUAL_STREAM: - _lora_aux_cuda_stream: torch.cuda.Stream | None = None - - def _get_lora_aux_cuda_stream() -> torch.cuda.Stream | None: - global _lora_aux_cuda_stream - if _lora_aux_cuda_stream is None and current_platform.is_cuda_alike(): - _lora_aux_cuda_stream = torch.cuda.Stream() - return _lora_aux_cuda_stream def lora_linear_async( layer_name: str, diff --git a/vllm/lora/layers/fused_moe.py b/vllm/lora/layers/fused_moe.py index dc83143a751a..46ec26334159 100644 --- a/vllm/lora/layers/fused_moe.py +++ b/vllm/lora/layers/fused_moe.py @@ -19,8 +19,9 @@ from vllm.model_executor.layers.fused_moe.prepare_finalize import ( MoEPrepareAndFinalizeNoDPEPModular, ) +from vllm.platforms import current_platform -from .utils import _get_lora_device +from .utils import _get_lora_aux_cuda_stream, _get_lora_device class FusedMoEWithLoRA(BaseLayerWithLoRA): @@ -34,6 +35,9 @@ def __init__(self, base_layer: FusedMoE) -> None: self.tp_size = self.base_layer.tp_size self.tp_rank = self.base_layer.tp_rank self.device = _get_lora_device(base_layer) + + self._enable_aux_cuda_stream = envs.VLLM_LORA_ENABLE_DUAL_STREAM + self._init_lora_stream_context() # For non-gated MoE (is_act_and_mul=False), only 1 slice is needed # since there's only up_proj (w1), not gate_proj + up_proj (w1 + w3) self._w13_slices = 2 if base_layer.moe_config.is_act_and_mul else 1 @@ -65,7 +69,25 @@ def __init__(self, base_layer: FusedMoE) -> None: FusedMoEModularMethod(self.base_layer.quant_method, moe_kernel) ) + def _init_lora_stream_context(self) -> None: + self._lora_stream: torch.cuda.Stream | None = None + self._events: tuple[torch.cuda.Event, ...] | None = None + if not self._enable_aux_cuda_stream: + return + if not current_platform.is_cuda_alike(): + return + self._lora_stream = _get_lora_aux_cuda_stream() + # 4 events: 2 per (base GEMM, LoRA) pair so w13 and w2 don't reuse + # the same event objects; reuse-within-a-pair is fine because the + # second pair starts only after intermediate_cache1.add_() has joined. + self._events = tuple(torch.cuda.Event() for _ in range(4)) + def _build_lora_context(self): + use_dual_stream = ( + self._enable_aux_cuda_stream + and not self.fully_sharded + and self._lora_stream is not None + ) return MoELoRAContext( w13_lora_a_stacked=self.w13_lora_a_stacked, w13_lora_b_stacked=self.w13_lora_b_stacked, @@ -81,6 +103,8 @@ def _build_lora_context(self): local_num_experts=self.base_layer.local_num_experts, punica_wrapper=self.punica_wrapper, use_tuned_config=bool(envs.VLLM_TUNED_CONFIG_FOLDER), + aux_stream=self._lora_stream if use_dual_stream else None, + events=self._events if use_dual_stream else None, ) def _create_lora_a_weights( diff --git a/vllm/lora/layers/utils.py b/vllm/lora/layers/utils.py index 1b8083f5c4d1..cb2054fb5f0b 100644 --- a/vllm/lora/layers/utils.py +++ b/vllm/lora/layers/utils.py @@ -7,9 +7,22 @@ import torch import torch.nn as nn +from vllm import envs from vllm.model_executor.layers.fused_moe.fused_moe import try_get_optimal_moe_config +from vllm.platforms import current_platform from vllm.utils.math_utils import next_power_of_2 +_lora_aux_cuda_stream: torch.cuda.Stream | None = None + + +def _get_lora_aux_cuda_stream() -> torch.cuda.Stream | None: + if not envs.VLLM_LORA_ENABLE_DUAL_STREAM: + return None + global _lora_aux_cuda_stream + if _lora_aux_cuda_stream is None and current_platform.is_cuda_alike(): + _lora_aux_cuda_stream = torch.cuda.Stream() + return _lora_aux_cuda_stream + class LoRAMappingType(Enum): LANGUAGE = 1 diff --git a/vllm/lora/ops/triton_ops/fused_moe_lora_op.py b/vllm/lora/ops/triton_ops/fused_moe_lora_op.py index 42f53b200fa3..fbad39f43a16 100644 --- a/vllm/lora/ops/triton_ops/fused_moe_lora_op.py +++ b/vllm/lora/ops/triton_ops/fused_moe_lora_op.py @@ -105,6 +105,742 @@ def _get_c_ptrs( _LORA_PTR_DICT: dict[tuple[int, ...], torch.tensor] = {} +# --------------------------------------------------------------------------- +# Fully-fused MoE-LoRA kernel (one-shot): shrink + expand combined into a single +# launch with the rank-dim intermediate kept in registers. Used by the fast +# path of `_fused_moe_lora` for `fully_sharded=False`. The legacy two-kernel +# path (`_fused_moe_lora_kernel` above) is retained for `fully_sharded=True` +# because that path needs to materialise the intermediate cache for an +# all_reduce / all_gather between shrink and expand. +# --------------------------------------------------------------------------- + + +@triton.heuristics({"EVEN_K": lambda args: args["K"] % args["BLOCK_K"] == 0}) +@triton.jit +def _fused_moe_lora_one_shot_kernel( + # ---- pointers ---- + x_ptr, + A_ptrs, + B_ptrs, + out_ptr, + topk_weights_ptr, + sorted_token_ids_ptr, + expert_ids_ptr, + num_tokens_post_padded_ptr, + token_lora_mapping_ptr, + lora_ids_ptr, + adapter_enabled_ptr, + # ---- dims ---- + N, + K, + num_valid_tokens, + top_k_num, + max_loras, + # ---- strides ---- + stride_xm, + stride_xk, + stride_A_lora, + stride_A_expert, + stride_A_r, + stride_A_k, + stride_B_lora, + stride_B_expert, + stride_B_n, + stride_B_r, + stride_om, + stride_on, + stride_tl_, + stride_el, + # ---- scalar ---- + slice_n_offset, + # ---- constexpr (set per call) ---- + token_mapping_factor: tl.constexpr, + naive_block_assignment: tl.constexpr, + MUL_ROUTED_WEIGHT: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_R: tl.constexpr, + actual_rank: tl.constexpr, + NPID_FACTOR: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + EVEN_K: tl.constexpr, + ADD_INPUTS: tl.constexpr, +): + pid_full = tl.program_id(axis=0) + pid_m = pid_full // NPID_FACTOR + pid_n_outer = pid_full % NPID_FACTOR + slice_id = tl.program_id(axis=1) + lora_idx = tl.program_id(axis=2) + + # Resolve lora_id. + if naive_block_assignment: + token_idx_for_lora = pid_m // top_k_num + lora_id = tl.load(token_lora_mapping_ptr + token_idx_for_lora) + else: + lora_id = tl.load(lora_ids_ptr + lora_idx) + if lora_id < 0: + return + if lora_id >= max_loras: + return + enabled = tl.load(adapter_enabled_ptr + lora_id) + if enabled == 0: + return + + if not naive_block_assignment: + ntpp = tl.load(num_tokens_post_padded_ptr + lora_id) + if pid_m * BLOCK_M >= ntpp: + return + + # Resolve expert_id. + if naive_block_assignment: + expert_id = tl.load(expert_ids_ptr + pid_m) + else: + ind = lora_id * stride_el + pid_m + expert_id = tl.load( + expert_ids_ptr + ind, mask=ind < max_loras * stride_el, other=-1 + ) + if expert_id < 0: + return + + # Compute offs_token (flat token ids). + offs = tl.arange(0, BLOCK_M).to(tl.int64) + if naive_block_assignment: + offs_token = tl.where(offs == 0, pid_m, num_valid_tokens) + else: + offs_token_id = pid_m * BLOCK_M + offs + token_ind = stride_tl_ * lora_id + offs_token_id + offs_token = tl.load( + sorted_token_ids_ptr + token_ind, + mask=token_ind < max_loras * stride_tl_, + other=num_valid_tokens, + ) + token_mask = offs_token < num_valid_tokens + + # N range owned by this program. Splitting [0, N) into NPID_FACTOR + # contiguous outer blocks lets us scale parallelism for small batches. + n_per_outer = tl.cdiv(N, NPID_FACTOR) + n_lo = pid_n_outer * n_per_outer + n_hi = tl.minimum((pid_n_outer + 1) * n_per_outer, N) + if n_lo >= N: + return + + # Slice pointers. + cur_A_ptr = tl.load(A_ptrs + slice_id).to(tl.pointer_type(out_ptr.dtype.element_ty)) + cur_B_ptr = tl.load(B_ptrs + slice_id).to(tl.pointer_type(out_ptr.dtype.element_ty)) + + A_base = cur_A_ptr + lora_id * stride_A_lora + expert_id * stride_A_expert + B_base = cur_B_ptr + lora_id * stride_B_lora + expert_id * stride_B_expert + + # SHRINK: tmp[BLOCK_M, BLOCK_R] = x @ A^T, accumulated in fp32 registers. + offs_r = tl.arange(0, BLOCK_R) + rank_mask = offs_r < actual_rank + # Clamp rank offsets so OOB rows of A / B map to address 0; the mask + # zeros the loaded values. Required when BLOCK_R > actual_rank + # (e.g. rank=4 padded to 16) -- without clamping, tl.load would address + # the next expert's memory. + safe_offs_r = tl.where(rank_mask, offs_r, 0) + offs_k = tl.arange(0, BLOCK_K) + + offs_x_row = offs_token // token_mapping_factor + x_ptrs = x_ptr + offs_x_row[:, None] * stride_xm + offs_k[None, :] * stride_xk + a_ptrs = A_base + offs_k[:, None] * stride_A_k + safe_offs_r[None, :] * stride_A_r + + tmp = tl.zeros((BLOCK_M, BLOCK_R), dtype=tl.float32) + if EVEN_K: + for _ in range(0, K, BLOCK_K): + x = tl.load(x_ptrs, mask=token_mask[:, None], other=0.0) + a = tl.load(a_ptrs, mask=rank_mask[None, :], other=0.0) + tmp += tl.dot(x, a) + x_ptrs += BLOCK_K * stride_xk + a_ptrs += BLOCK_K * stride_A_k + else: + for kb in range(0, K, BLOCK_K): + k_remain = K - kb + k_mask = offs_k < k_remain + x = tl.load(x_ptrs, mask=token_mask[:, None] & k_mask[None, :], other=0.0) + a = tl.load(a_ptrs, mask=k_mask[:, None] & rank_mask[None, :], other=0.0) + tmp += tl.dot(x, a) + x_ptrs += BLOCK_K * stride_xk + a_ptrs += BLOCK_K * stride_A_k + + tmp_typed = tmp.to(out_ptr.dtype.element_ty) + + # EXPAND: out[tokens, n] += tmp @ B^T, looped over BLOCK_N tiles within + # this program's [n_lo, n_hi). The (offs_n < n_hi) mask is required + # whenever BLOCK_N > n_per_outer to keep adjacent outer blocks from + # writing into each other's columns. + if MUL_ROUTED_WEIGHT: + moe_w = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0.0).to( + tl.float32 + ) + + out_slice_base = out_ptr + slice_id * slice_n_offset + + for n_start in range(n_lo, n_hi, BLOCK_N): + offs_n = n_start + tl.arange(0, BLOCK_N) + n_mask = (offs_n < N) & (offs_n < n_hi) + + b_ptrs = ( + B_base + safe_offs_r[:, None] * stride_B_r + offs_n[None, :] * stride_B_n + ) + b = tl.load(b_ptrs, mask=rank_mask[:, None] & n_mask[None, :], other=0.0) + + acc = tl.dot(tmp_typed, b) # (BLOCK_M, BLOCK_N) fp32 + if MUL_ROUTED_WEIGHT: + acc = acc * moe_w[:, None] + + out_ptrs = ( + out_slice_base + + offs_token[:, None] * stride_om + + offs_n[None, :] * stride_on + ) + out_mask = token_mask[:, None] & n_mask[None, :] + if ADD_INPUTS: + prev = tl.load(out_ptrs, mask=out_mask, other=0.0) + tl.store(out_ptrs, prev + acc.to(out_ptr.dtype.element_ty), mask=out_mask) + else: + tl.store(out_ptrs, acc.to(out_ptr.dtype.element_ty), mask=out_mask) + + +def _run_fused_moe_lora_one_shot( + output: torch.Tensor, + qcurr_hidden_states: torch.Tensor, + lora_a_stacked: list[torch.Tensor], + lora_b_stacked: list[torch.Tensor], + topk_weights: torch.Tensor, + sorted_token_ids: torch.Tensor | None, + expert_ids: torch.Tensor, + num_tokens_post_padded: torch.Tensor | None, + token_lora_mapping: torch.Tensor, + max_lora_rank: int, + top_k_num: int, + lora_ids: torch.Tensor, + num_active_loras: torch.Tensor, + adapter_enabled: torch.Tensor, + mul_routed_weight: bool, + block_size_m: int, + add_inputs: bool = True, +) -> None: + """Fast-path wrapper: launches one fused shrink+expand kernel. + + The shape contract matches `_fused_moe_lora`. `output` has shape + `(num_tokens, top_k_num, num_slices * N_per_slice)`. When + `add_inputs=True` (default) the kernel reads-modifies-writes `output` + in place; when `add_inputs=False` the kernel overwrites `output` with + the LoRA delta only. The latter is used by the dual-stream path that + sums LoRA into the base output on a separate stream. + """ + num_slices = len(lora_a_stacked) + device = qcurr_hidden_states.device + + A0 = lora_a_stacked[0] + B0 = lora_b_stacked[0] + max_loras_w = A0.shape[0] + rank = A0.shape[2] + K = A0.shape[3] + N_per_slice = B0.shape[2] + + # rank padding is to next pow2 with a floor of 16 (tensor-core minimum + # K-dim). Beyond 128 the (BLOCK_M, BLOCK_R) accumulator outgrows the + # register file; rank tiling would be needed but is out of scope for + # this kernel. Tried floor=32 to double MMA density per K-step but it + # regressed across all M (+8 to +40%): the (64,32) fp32 accumulator + + # widened B tile pushed register count past spill threshold, lowering + # occupancy by more than the MMA gain saved. + assert rank <= 128, ( + f"fused_moe_lora_one_shot supports max_lora_rank<=128; got rank={rank}" + ) + BLOCK_R = max(triton.next_power_of_2(rank), 16) + + num_experts = A0.shape[1] + naive = sorted_token_ids is None + if sorted_token_ids is None: + EM_grid = topk_weights.numel() + BLOCK_M = 16 + stride_tl_ = 0 + stride_el = 0 + grid_lora_dim = 1 + else: + EM_grid = sorted_token_ids.shape[1] + # BLOCK_M must equal moe_lora_align_block_size's block_size. The + # caller passes that explicitly; deriving it from tensor shapes is + # unsafe because sorted_token_ids.shape[1] is the raw padded length + # (not necessarily a multiple of block_size — e.g. OLMoE prefill + # produces sorted=139200 with expert_ids=1088 and block_size=128). + # tl.arange and tl.dot need block_size_m to be a power of 2 and at + # least 16. The Python-side assertion gives a clearer error than + # the cryptic Triton compile failure. + assert block_size_m >= 16 and (block_size_m & (block_size_m - 1)) == 0, ( + f"shrink_block_size_m must be a power of 2 and >=16; got {block_size_m}" + ) + BLOCK_M = block_size_m + stride_tl_ = sorted_token_ids.stride(0) + stride_el = expert_ids.stride(0) + grid_lora_dim = int(num_active_loras.item()) + + # Empty-work guards: the grid would otherwise have a zero dimension, + # which Triton rejects. None of these is a hot path in production -- a + # batch with zero tokens, an EM_grid of zero, or zero active LoRAs all + # mean there's nothing to add to `output`. + if EM_grid == 0 or grid_lora_dim == 0 or num_slices == 0: + return + + token_mapping_factor = 1 if mul_routed_weight else top_k_num + + A_ptrs = _get_ptr(lora_a_stacked, device) + B_ptrs = _get_ptr(lora_b_stacked, device) + + # Flatten (num_tokens, top_k) → flat_token axis. The kernel addresses + # output via offs_token * stride_om, which is correct iff the dim-0 / + # dim-1 strides collapse cleanly: stride(0) == top_k * stride(1). All + # production callers pass contiguous output, so this always holds; the + # explicit check guards against future regressions where a non-trivial + # view (e.g. permute) would silently break in-place accumulation. + assert output.dim() == 3, f"output must be 3-D, got {output.shape}" + assert output.stride(0) == output.shape[1] * output.stride(1), ( + "fused_moe_lora_one_shot requires output.stride(0) == top_k*stride(1); " + f"got shape={output.shape} strides={output.stride()}" + ) + out_view = output.view(-1, output.shape[-1]) + M_blocks = triton.cdiv(EM_grid, BLOCK_M) if not naive else EM_grid + + # NPID_FACTOR heuristic: scale N-axis parallelism when base CTA count is + # short of saturating the SM array. Cap by the cost of redundant shrink. + sm_count = torch.cuda.get_device_properties(device).multi_processor_count + base_programs = max(M_blocks * num_slices * grid_lora_dim, 1) + shrink_ratio = K / max(K + N_per_slice, 1) + max_npid_by_budget = max(1, int(1.5 / max(shrink_ratio, 1e-3)) + 1) + target = 2 * sm_count + if base_programs >= int(1.5 * sm_count): + npid = 1 + else: + npid_occ = max(1, min(16, (target + base_programs - 1) // base_programs)) + npid = min(npid_occ, max_npid_by_budget) + npid = max(1, min(npid, max(1, N_per_slice // 128))) + + # Robust defaults across the prefill regime (H100/H200/B200, bf16/fp16). + # NPID > 1 is the small-M / under-saturated path -- more warps help + # amortise the inner-N expand loop. ns=3 instead of 4: GB200 ncu showed + # the 4-stage pipeline pushed register count to 168/thread and capped + # achieved occupancy at ~17% (3 blocks/SM, register-bound); ns=3 frees + # ~30 regs/thread which keeps a 4th block resident on small grids. + # Tried BLOCK_N=64 for w13 (N=192) to avoid the half-wasted second + # tile: regressed 11-29% because the "waste" was just masked stores + # (cheap) and the extra iteration added load + index overhead. + if npid > 1: + block_n, nw, ns = 128, 8, 3 + else: + block_n, nw, ns = 128, 4, 3 + # BLOCK_K choice: for hidden-sized K (≥256, i.e. the K=hidden_size + # shrink input on w13) force BLOCK_K=128 -- the wider tile halves the + # K-loop trip count and removes the scoreboard stalls that dominated + # M=16-64 on GB200 (kernel time -13% to -37% vs the work_per_expert + # heuristic which picked 64 for low-tokens-per-expert ratios). For + # small-K shapes (e.g. w2 with K=192 where the down-proj reads the + # MoE intermediate) keep the work_per_expert heuristic: BLOCK_K=128 + # would force the EVEN_K=False masked path and add no K-loop savings + # (K/64=3 vs K/128=2 masked) while inflating per-program startup. + if K >= 256: + block_k = 128 + else: + work_per_expert = topk_weights.numel() / max(num_experts, 1) + block_k = 128 if work_per_expert >= 16 else 64 + + grid = (M_blocks * npid, num_slices, grid_lora_dim) + + _fused_moe_lora_one_shot_kernel[grid]( + qcurr_hidden_states, + A_ptrs, + B_ptrs, + out_view, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + token_lora_mapping, + lora_ids, + adapter_enabled, + N_per_slice, + K, + topk_weights.numel(), + top_k_num, + max_loras_w, + qcurr_hidden_states.stride(0), + qcurr_hidden_states.stride(1), + A0.stride(0), + A0.stride(1), + A0.stride(2), + A0.stride(3), + B0.stride(0), + B0.stride(1), + B0.stride(2), + B0.stride(3), + out_view.stride(0), + out_view.stride(1), + stride_tl_, + stride_el, + N_per_slice, + token_mapping_factor=token_mapping_factor, + naive_block_assignment=naive, + MUL_ROUTED_WEIGHT=mul_routed_weight, + BLOCK_M=BLOCK_M, + BLOCK_R=BLOCK_R, + actual_rank=rank, + NPID_FACTOR=npid, + BLOCK_N=block_n, + BLOCK_K=block_k, + ADD_INPUTS=add_inputs, + num_warps=nw, + num_stages=ns, + ) + + +# --------------------------------------------------------------------------- +# Small-batch (decode-style) fused MoE-LoRA kernel — sub-path of the +# one_shot fast path. +# --------------------------------------------------------------------------- + + +@triton.heuristics({"EVEN_K": lambda args: args["K"] % args["BLOCK_K"] == 0}) +@triton.jit +def _fused_moe_lora_small_batch_kernel( + # ---- pointers ---- + x_ptr, + A_ptrs, + B_ptrs, + out_ptr, + topk_weights_ptr, + expert_ids_ptr, # (num_tokens * top_k_num,) + token_lora_mapping_ptr, # (num_tokens,) + adapter_enabled_ptr, + # ---- dims ---- + N, + K, + top_k_num, + max_loras, + work_total, # = pair_slices * n_chunks_per_pair_slice + pair_slices, # = num_tokens * top_k_num * NUM_SLICES + # ---- strides ---- + stride_xm, + stride_xk, + stride_A_lora, + stride_A_expert, + stride_A_r, + stride_A_k, + stride_B_lora, + stride_B_expert, + stride_B_n, + stride_B_r, + stride_om, + stride_on, + # ---- scalar (runtime ints, NOT constexpr) ---- + # n_tiles_per_program / n_chunks_per_pair_slice are deliberately + # runtime: each distinct value would otherwise trigger a fresh Triton + # compile -> fresh kernel binary -> fresh CUDA graph instance per + # batch size. Production traces showed that variant explosion adding + # ~5.9k graph instantiations on top of legacy. Runtime args mean one + # shared binary across all chunk sizes. + slice_n_offset, + n_tiles_per_program, + n_chunks_per_pair_slice, + # ---- constexpr ---- + token_mapping_factor: tl.constexpr, + MUL_ROUTED_WEIGHT: tl.constexpr, + ADD_INPUTS: tl.constexpr, + BLOCK_R: tl.constexpr, + actual_rank: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + NUM_SLICES: tl.constexpr, + EVEN_K: tl.constexpr, +): + """Persistent fused MoE-LoRA kernel for naive_block_assignment inputs. + + Each program owns one (pair × slice × n_chunk) work item. A "chunk" + covers `n_tiles_per_program` consecutive output-N tiles, all of which + share a single shrink — so the rank-vector is computed once per + program and the A weights for that (lora, expert, slice) are loaded + once instead of n_tiles_per_program times. + + The wrapper picks `n_tiles_per_program` to keep the grid close to + 2*SM_count: at very small batch (work_total ≤ SM_count) the chunk + size collapses to 1 and behaviour matches a per-tile GEMV; as batch + grows the chunk grows so we trade some N-axis parallelism for shrink + reuse. When `work_total` exceeds the launched grid, the outer stride + loop drains the leftover work units serially. + """ + pid = tl.program_id(axis=0) + num_programs = tl.num_programs(axis=0) + + offs_r = tl.arange(0, BLOCK_R) + rank_mask = offs_r < actual_rank + # Clamp OOB rank lanes so they address row 0 of A/B; the mask zeros + # the loaded values. Required when BLOCK_R > actual_rank (e.g. rank=4 + # padded to 16) -- without clamping, tl.load would address the next + # expert's memory. + safe_offs_r = tl.where(rank_mask, offs_r, 0) + offs_k = tl.arange(0, BLOCK_K) + + # Persistent stride loop: when grid < work_total each program walks + # multiple work items. When grid == work_total the loop runs exactly + # once and the kernel degenerates to the per-tile GEMV. + for work_id in range(pid, work_total, num_programs): + n_chunk_idx = work_id % n_chunks_per_pair_slice + pair_slice_idx = work_id // n_chunks_per_pair_slice + # NUM_SLICES is constexpr (typ. 1 or 2) so divmod folds. + pair_idx = pair_slice_idx // NUM_SLICES + slice_id = pair_slice_idx % NUM_SLICES + + # Resolve lora_id / expert_id; skip the body for inactive lanes. + # Using a single `valid` flag instead of early `return` keeps the + # outer stride loop alive — `return` would exit the whole program + # and skip later work items assigned to this SM. + token_idx = pair_idx // top_k_num + lora_id = tl.load(token_lora_mapping_ptr + token_idx) + valid = (lora_id >= 0) & (lora_id < max_loras) + enabled = tl.load(adapter_enabled_ptr + tl.where(valid, lora_id, 0)) + valid = valid & (enabled != 0) + expert_id = tl.load(expert_ids_ptr + pair_idx) + valid = valid & (expert_id >= 0) + + if valid: + cur_A_ptr = tl.load(A_ptrs + slice_id).to( + tl.pointer_type(out_ptr.dtype.element_ty) + ) + cur_B_ptr = tl.load(B_ptrs + slice_id).to( + tl.pointer_type(out_ptr.dtype.element_ty) + ) + A_base = cur_A_ptr + lora_id * stride_A_lora + expert_id * stride_A_expert + B_base = cur_B_ptr + lora_id * stride_B_lora + expert_id * stride_B_expert + + x_row = pair_idx // token_mapping_factor + x_row_ptr = x_ptr + x_row * stride_xm + + # SHRINK GEMV (once per program; reused across n_tiles_per_program + # expand tiles below). Sum-reduction over BLOCK_K with fp32 + # accumulator — same precision path as the one_shot kernel. + rank_vec = tl.zeros((BLOCK_R,), dtype=tl.float32) + if EVEN_K: + for kb in range(0, K, BLOCK_K): + cur_k = kb + offs_k + x_tile = tl.load(x_row_ptr + cur_k * stride_xk).to(tl.float32) + a_tile = tl.load( + A_base + + safe_offs_r[:, None] * stride_A_r + + cur_k[None, :] * stride_A_k, + mask=rank_mask[:, None], + other=0.0, + ).to(tl.float32) + rank_vec += tl.sum(a_tile * x_tile[None, :], axis=1) + else: + for kb in range(0, K, BLOCK_K): + cur_k = kb + offs_k + k_mask = cur_k < K + x_tile = tl.load( + x_row_ptr + cur_k * stride_xk, mask=k_mask, other=0.0 + ).to(tl.float32) + a_tile = tl.load( + A_base + + safe_offs_r[:, None] * stride_A_r + + cur_k[None, :] * stride_A_k, + mask=rank_mask[:, None] & k_mask[None, :], + other=0.0, + ).to(tl.float32) + rank_vec += tl.sum(a_tile * x_tile[None, :], axis=1) + + # EXPAND: walk n_tiles_per_program consecutive output-N tiles + # using the same rank_vec. The loop is a runtime range (not + # tl.static_range) so a single compiled kernel handles every + # chunk size — see the note on the kernel signature. + n_tile_start = n_chunk_idx * n_tiles_per_program + out_row_ptr = out_ptr + slice_id * slice_n_offset + pair_idx * stride_om + + if MUL_ROUTED_WEIGHT: + moe_w = tl.load(topk_weights_ptr + pair_idx).to(tl.float32) + + for nt in range(n_tiles_per_program): + n_lo = (n_tile_start + nt) * BLOCK_N + if n_lo < N: + offs_n = n_lo + tl.arange(0, BLOCK_N) + n_mask = offs_n < N + b_tile = tl.load( + B_base + + offs_n[:, None] * stride_B_n + + safe_offs_r[None, :] * stride_B_r, + mask=n_mask[:, None] & rank_mask[None, :], + other=0.0, + ).to(tl.float32) + out_tile = tl.sum(b_tile * rank_vec[None, :], axis=1) + + if MUL_ROUTED_WEIGHT: + out_tile = out_tile * moe_w + + out_ptrs = out_row_ptr + offs_n * stride_on + if ADD_INPUTS: + prev = tl.load(out_ptrs, mask=n_mask, other=0.0).to(tl.float32) + tl.store( + out_ptrs, + (prev + out_tile).to(out_ptr.dtype.element_ty), + mask=n_mask, + ) + else: + tl.store( + out_ptrs, + out_tile.to(out_ptr.dtype.element_ty), + mask=n_mask, + ) + + +def _pick_small_batch_chunk(pair_slices: int, N_tiles: int, sm_count: int) -> int: + """Pick `n_tiles_per_program` so the launched grid stays near + 2*SM_count. + + Sizes for occupancy first (more programs in flight → better latency + hiding for the K-loop A/x loads). Once the per-tile grid already + exceeds 2*SM_count we increase the chunk size to amortise the shrink + cost — at that point the GPU is saturated by per-program work and + packing more tiles per program lets the rank_vec be reused. + """ + target_grid = max(1, 2 * sm_count) + total_work = pair_slices * N_tiles + if total_work <= target_grid: + return 1 + ntpp = (total_work + target_grid - 1) // target_grid + return min(ntpp, N_tiles) + + +def _run_fused_moe_lora_small_batch( + output: torch.Tensor, + qcurr_hidden_states: torch.Tensor, + lora_a_stacked: list[torch.Tensor], + lora_b_stacked: list[torch.Tensor], + topk_weights: torch.Tensor, + expert_ids_flat: torch.Tensor, # (num_tokens * top_k_num,) + token_lora_mapping: torch.Tensor, + top_k_num: int, + adapter_enabled: torch.Tensor, + mul_routed_weight: bool, + add_inputs: bool = True, +) -> None: + """Small-batch GEMV-style wrapper. Naive-block-assignment inputs only. + + Shape contract matches `_run_fused_moe_lora_one_shot`: `output` is + `(num_tokens, top_k_num, num_slices * N_per_slice)` with + contiguous-style strides, `expert_ids_flat` is the flattened + `topk_ids` of shape `(num_tokens * top_k_num,)`, and the + rank-padded LoRA weights live in `lora_a_stacked` / + `lora_b_stacked`. + + The kernel is persistent over (pair × slice × n_chunk) work items — + each program does one shrink and reuses the rank vector across + `n_tiles_per_program` expand tiles. The chunk size scales with the + pair-slice count so very small batches keep per-tile parallelism + while medium batches cut redundant shrinks. + """ + num_slices = len(lora_a_stacked) + device = qcurr_hidden_states.device + + A0 = lora_a_stacked[0] + B0 = lora_b_stacked[0] + max_loras_w = A0.shape[0] + rank = A0.shape[2] + K = A0.shape[3] + N_per_slice = B0.shape[2] + + # Rank padding: floor 16 (tensor-core min K), ceil to next pow2. The + # ≤64 cap is set conservatively for the prototype: at rank 64 the + # per-program register footprint is rank_vec(64 fp32) + b_tile(BLOCK_N + # × 64 fp32) = e.g. 128*64*4 = 32 KiB, comfortably within the 64 KiB + # register file even with num_warps=8. Doubling to 128 would push us + # against the limit and require shared-memory staging. + assert rank <= 64, f"fused_moe_lora_small_batch supports rank<=64; got rank={rank}" + BLOCK_R = max(triton.next_power_of_2(rank), 16) + + num_tokens = topk_weights.shape[0] + M_grid = num_tokens * top_k_num + if M_grid == 0 or num_slices == 0: + return + + token_mapping_factor = 1 if mul_routed_weight else top_k_num + + A_ptrs = _get_ptr(lora_a_stacked, device) + B_ptrs = _get_ptr(lora_b_stacked, device) + + assert output.dim() == 3, f"output must be 3-D, got {output.shape}" + assert output.stride(0) == output.shape[1] * output.stride(1), ( + "fused_moe_lora_small_batch requires output.stride(0) == " + f"top_k*stride(1); got shape={output.shape} strides={output.stride()}" + ) + out_view = output.view(-1, output.shape[-1]) + + # Block sizes. BLOCK_N=128 matches the one_shot's expand tile and gives + # 6-24 N tiles for typical N ∈ [768, 3072], enough to saturate the SM + # array once M_grid * num_slices reaches ~SM_count. BLOCK_K=128 halves + # the K-loop trip count vs 64 and pays for itself once K ≥ 1024 (the + # only regime we care about — hidden sizes are always large here). + BLOCK_N = 128 + BLOCK_K = 128 + nw = 4 + ns = 3 + + N_tiles = triton.cdiv(N_per_slice, BLOCK_N) + pair_slices = M_grid * num_slices + + sm_count = torch.cuda.get_device_properties(device).multi_processor_count + n_tiles_per_program = _pick_small_batch_chunk(pair_slices, N_tiles, sm_count) + n_chunks = triton.cdiv(N_tiles, n_tiles_per_program) + work_total = pair_slices * n_chunks + + # Grid sizing: keep parallelism uncapped when work_total is small (so + # very small batches still spread across SMs); cap at 2*SM_count once + # we have plenty of work, letting the in-kernel stride loop drain the + # remainder. + grid_size = min(work_total, max(1, 2 * sm_count)) + grid = (grid_size,) + + _fused_moe_lora_small_batch_kernel[grid]( + qcurr_hidden_states, + A_ptrs, + B_ptrs, + out_view, + topk_weights, + expert_ids_flat, + token_lora_mapping, + adapter_enabled, + N_per_slice, + K, + top_k_num, + max_loras_w, + work_total, + pair_slices, + qcurr_hidden_states.stride(0), + qcurr_hidden_states.stride(1), + A0.stride(0), + A0.stride(1), + A0.stride(2), + A0.stride(3), + B0.stride(0), + B0.stride(1), + B0.stride(2), + B0.stride(3), + out_view.stride(0), + out_view.stride(1), + N_per_slice, + n_tiles_per_program, + n_chunks, + token_mapping_factor=token_mapping_factor, + MUL_ROUTED_WEIGHT=mul_routed_weight, + ADD_INPUTS=add_inputs, + BLOCK_R=BLOCK_R, + actual_rank=rank, + BLOCK_N=BLOCK_N, + BLOCK_K=BLOCK_K, + NUM_SLICES=num_slices, + num_warps=nw, + num_stages=ns, + ) + + def _get_ptr(lora_weights: list[torch.Tensor], device: torch.device): """ `_LORA_PTR_DICT` collects the required information during `profile_run`, @@ -706,6 +1442,7 @@ def _fused_moe_lora( mul_routed_weight: bool = False, fully_sharded: bool = False, offset: int = 0, + add_inputs: bool = True, ) -> None: assert len(lora_a_stacked) == len(lora_b_stacked) > 0 assert topk_weights.dim() == qcurr_hidden_states.dim() == 2 @@ -728,6 +1465,59 @@ def _fused_moe_lora( ) assert output.shape[0] == topk_weights.shape[0] assert top_k_num == topk_weights.shape[1] + + # Fast path: single fused kernel + if not fully_sharded: + M_pairs = topk_weights.numel() + if ( + sorted_token_ids is None + and max_lora_rank <= 64 + and M_pairs * max_lora_rank <= 1024 + ): + _run_fused_moe_lora_small_batch( + output, + qcurr_hidden_states, + lora_a_stacked, + lora_b_stacked, + topk_weights, + expert_ids, + token_lora_mapping, + top_k_num, + adapter_enabled, + mul_routed_weight, + add_inputs=add_inputs, + ) + return + # shrink/expand BLOCK_SIZE_M must match the block_size that + # moe_lora_align_block_size used; both shrink and expand pass the + # same value (asserted by `shrink_block_size_m == expand_block_size_m` + # below). + _run_fused_moe_lora_one_shot( + output, + qcurr_hidden_states, + lora_a_stacked, + lora_b_stacked, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + token_lora_mapping, + max_lora_rank, + top_k_num, + lora_ids, + num_active_loras, + adapter_enabled, + mul_routed_weight, + shrink_block_size_m, + add_inputs=add_inputs, + ) + return + + assert add_inputs, ( + "fused_moe_lora(add_inputs=False) is only supported on the " + "fully_sharded=False fast path" + ) + device = qcurr_hidden_states.device num_slices = len(lora_a_stacked) w1_lora_b_stacked = lora_b_stacked[0] @@ -894,6 +1684,7 @@ def _fused_moe_lora_fake( mul_routed_weight: bool = False, fully_sharded: bool = False, offset: int = 0, + add_inputs: bool = True, ) -> None: return diff --git a/vllm/lora/punica_wrapper/punica_base.py b/vllm/lora/punica_wrapper/punica_base.py index 0448a6d00cda..086712991907 100644 --- a/vllm/lora/punica_wrapper/punica_base.py +++ b/vllm/lora/punica_wrapper/punica_base.py @@ -514,6 +514,7 @@ def add_lora_w13( num_slices: int, fully_sharded: bool, use_tuned_config: bool, + add_inputs: bool = True, token_lora_mapping: torch.Tensor | None = None, ) -> tuple[ torch.Tensor | None, @@ -554,6 +555,7 @@ def add_lora_w2( fully_sharded: bool, tp_rank: int, use_tuned_config: bool, + add_inputs: bool = True, ) -> None: """Apply w2 LoRA to y (intermediate_cache3) in-place before moe_sum. diff --git a/vllm/lora/punica_wrapper/punica_gpu.py b/vllm/lora/punica_wrapper/punica_gpu.py index bf951e074949..87500ec3ec25 100644 --- a/vllm/lora/punica_wrapper/punica_gpu.py +++ b/vllm/lora/punica_wrapper/punica_gpu.py @@ -435,6 +435,7 @@ def add_lora_fused_moe( fully_sharded: bool = False, offset: int = 0, token_lora_mapping: torch.Tensor | None = None, + add_inputs: bool = True, ): """ Performs a fused forward computation for LoRA of Mixture-of-Experts (MoE) layer. @@ -484,6 +485,7 @@ def add_lora_fused_moe( mul_routed_weight, fully_sharded, offset, + add_inputs, ) def add_lora_w13( @@ -506,6 +508,7 @@ def add_lora_w13( num_slices: int, fully_sharded: bool, use_tuned_config: bool, + add_inputs: bool = True, token_lora_mapping: torch.Tensor | None = None, ) -> tuple[ torch.Tensor | None, @@ -610,6 +613,7 @@ def add_lora_w13( adapter_enabled, fully_sharded=fully_sharded, token_lora_mapping=token_lora_mapping, + add_inputs=add_inputs, ) return ( @@ -640,6 +644,7 @@ def add_lora_w2( fully_sharded: bool, tp_rank: int, use_tuned_config: bool, + add_inputs: bool = True, ) -> None: import functools @@ -722,4 +727,5 @@ def add_lora_w2( fully_sharded=fully_sharded, offset=offset, token_lora_mapping=token_lora_mapping, + add_inputs=add_inputs, ) diff --git a/vllm/lora/punica_wrapper/punica_xpu.py b/vllm/lora/punica_wrapper/punica_xpu.py index 58316cb75970..7fdadad09391 100755 --- a/vllm/lora/punica_wrapper/punica_xpu.py +++ b/vllm/lora/punica_wrapper/punica_xpu.py @@ -390,6 +390,7 @@ def add_lora_fused_moe( fully_sharded: bool = False, offset: int = 0, token_lora_mapping: torch.Tensor | None = None, + add_inputs: bool = True, ): """ Performs a fused forward computation for LoRA of Mixture-of-Experts (MoE) layer. @@ -439,6 +440,7 @@ def add_lora_fused_moe( mul_routed_weight, fully_sharded, offset, + add_inputs, ) def add_lora_w13( @@ -461,6 +463,7 @@ def add_lora_w13( num_slices: int, fully_sharded: bool, use_tuned_config: bool, + add_inputs: bool = True, token_lora_mapping: torch.Tensor | None = None, ) -> tuple[ torch.Tensor | None, @@ -594,6 +597,7 @@ def add_lora_w2( fully_sharded: bool, tp_rank: int, use_tuned_config: bool, + add_inputs: bool = True, ) -> None: import functools diff --git a/vllm/model_executor/layers/fused_moe/experts/lora_context.py b/vllm/model_executor/layers/fused_moe/experts/lora_context.py index ab1f0bfc1476..404457bb34bd 100644 --- a/vllm/model_executor/layers/fused_moe/experts/lora_context.py +++ b/vllm/model_executor/layers/fused_moe/experts/lora_context.py @@ -43,6 +43,16 @@ class MoELoRAContext: # try_get_optimal_moe_lora_config for Triton kernel tile configs. use_tuned_config: bool + # Optional dual-stream support for overlapping each (base GEMM, LoRA) + # pair. When aux_stream is None, the experts.apply() path runs the + # original sequential schedule. When set, base GEMM runs on the default + # stream and the LoRA fast-path writes the delta into a fresh buffer on + # aux_stream, which the default stream sums in afterwards. + # Events are paired one-per-overlap-pair: events[0,1] for w13, + # events[2,3] for w2, so the two pairs do not race on the same event. + aux_stream: torch.cuda.Stream | None = None + events: tuple[torch.cuda.Event, ...] | None = None + # Per-rank token→LoRA mapping after EP dispatch. Set by # FusedMoEPrepareAndFinalizeModular.prepare() when EP+LoRA is active, read # by LoRAExpertsMixin helpers in place of punica_wrapper's global mapping. diff --git a/vllm/model_executor/layers/fused_moe/experts/lora_experts_mixin.py b/vllm/model_executor/layers/fused_moe/experts/lora_experts_mixin.py index 2a680909d5f6..a47145b9493f 100644 --- a/vllm/model_executor/layers/fused_moe/experts/lora_experts_mixin.py +++ b/vllm/model_executor/layers/fused_moe/experts/lora_experts_mixin.py @@ -45,6 +45,7 @@ def apply_w13_lora( w2: torch.Tensor, num_tokens: int, top_k_num: int, + add_inputs: bool = True, ) -> tuple[ torch.Tensor | None, torch.Tensor | None, @@ -70,6 +71,7 @@ def apply_w13_lora( lora_context.w13_num_slices, lora_context.fully_sharded, lora_context.use_tuned_config, + add_inputs=add_inputs, token_lora_mapping=lora_context.local_token_lora_mapping, ) @@ -88,6 +90,7 @@ def apply_w2_lora( w1: torch.Tensor, w2: torch.Tensor, top_k_num: int, + add_inputs: bool = True, ) -> None: lora_context.punica_wrapper.add_lora_w2( y, @@ -109,4 +112,5 @@ def apply_w2_lora( lora_context.fully_sharded, lora_context.tp_rank, lora_context.use_tuned_config, + add_inputs=add_inputs, ) diff --git a/vllm/model_executor/layers/fused_moe/experts/triton_moe.py b/vllm/model_executor/layers/fused_moe/experts/triton_moe.py index 1ba39b35fd4e..45b95b341020 100644 --- a/vllm/model_executor/layers/fused_moe/experts/triton_moe.py +++ b/vllm/model_executor/layers/fused_moe/experts/triton_moe.py @@ -48,6 +48,7 @@ ) from vllm.platforms import current_platform from vllm.triton_utils import tl +from vllm.utils.multi_stream_utils import maybe_execute_in_parallel class TritonExperts(LoRAExpertsMixin, mk.FusedMoEExpertsModular): @@ -247,55 +248,99 @@ def apply( ) ) - invoke_fused_moe_triton_kernel( - hidden_states, - w1, - intermediate_cache1, - a1q_scale if a1q_scale is not None else self.a1_scale, - self.w1_scale, - None, # topk_weights - sorted_token_ids, - expert_ids, - num_tokens_post_padded, - False, # mul_routed_weights - top_k_num, - config, - compute_type=compute_type, - use_fp8_w8a8=self.quant_config.use_fp8_w8a8, - use_int8_w8a8=self.quant_config.use_int8_w8a8, - use_int8_w8a16=self.quant_config.use_int8_w8a16, - use_int4_w4a16=self.quant_config.use_int4_w4a16, - per_channel_quant=self.per_act_token_quant, - block_shape=self.block_shape, - B_bias=self.w1_bias, - ) + # LoRA w13: applied to intermediate_cache1 before activation. When + # the LoRA layer requested a dual-stream schedule, we run base w13 + # GEMM on the default stream and the LoRA fast-path on aux_stream; + # the LoRA writes its delta into a fresh zero buffer (add_inputs= + # False) and we sum it into intermediate_cache1 after both finish. - # LoRA w13: applied to intermediate_cache1 before activation, using - # hidden_states as the lora_a input. moe_lora_align_block_size is - # called once here and results reused for the w2 LoRA below. sorted_token_ids_lora = None expert_ids_lora = None num_tokens_post_padded_lora = None token_lora_mapping = None lora_context = self._lora_context - if lora_context is not None: + + def _base_w13_fn(): + invoke_fused_moe_triton_kernel( + hidden_states, + w1, + intermediate_cache1, + a1q_scale if a1q_scale is not None else self.a1_scale, + self.w1_scale, + None, # topk_weights + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + False, # mul_routed_weights + top_k_num, + config, + compute_type=compute_type, + use_fp8_w8a8=self.quant_config.use_fp8_w8a8, + use_int8_w8a8=self.quant_config.use_int8_w8a8, + use_int8_w8a16=self.quant_config.use_int8_w8a16, + use_int4_w4a16=self.quant_config.use_int4_w4a16, + per_channel_quant=self.per_act_token_quant, + block_shape=self.block_shape, + B_bias=self.w1_bias, + ) + + if lora_context is not None and lora_context.aux_stream is not None: + # add_inputs=False: kernel overwrites lora_delta_w13. zeros (not + # empty) so untouched rows -- e.g. blocks where every program + # early-exits because lora_id<0 -- stay at zero and the trailing + # add_() is a no-op there. + lora_delta_w13 = torch.zeros_like(intermediate_cache1) + + def _lora_w13_fn(): + return self.apply_w13_lora( + lora_context, + y=lora_delta_w13, + x=hidden_states, + topk_ids=topk_ids, + topk_weights=topk_weights, + expert_map=expert_map, + w1=w1, + w2=w2, + num_tokens=num_tokens, + top_k_num=top_k_num, + add_inputs=False, + ) + + assert lora_context.events is not None + _, lora_meta = maybe_execute_in_parallel( + _base_w13_fn, + _lora_w13_fn, + lora_context.events[0], + lora_context.events[1], + lora_context.aux_stream, + ) ( sorted_token_ids_lora, expert_ids_lora, num_tokens_post_padded_lora, token_lora_mapping, - ) = self.apply_w13_lora( - lora_context, - y=intermediate_cache1, - x=hidden_states, - topk_ids=topk_ids, - topk_weights=topk_weights, - expert_map=expert_map, - w1=w1, - w2=w2, - num_tokens=num_tokens, - top_k_num=top_k_num, - ) + ) = lora_meta + intermediate_cache1.add_(lora_delta_w13) + else: + _base_w13_fn() + if lora_context is not None: + ( + sorted_token_ids_lora, + expert_ids_lora, + num_tokens_post_padded_lora, + token_lora_mapping, + ) = self.apply_w13_lora( + lora_context, + y=intermediate_cache1, + x=hidden_states, + topk_ids=topk_ids, + topk_weights=topk_weights, + expert_map=expert_map, + w1=w1, + w2=w2, + num_tokens=num_tokens, + top_k_num=top_k_num, + ) a2q_scale: torch.Tensor | None = None @@ -328,48 +373,82 @@ def apply( quantization_emulation=self.quantization_emulation, ) - invoke_fused_moe_triton_kernel( - qintermediate_cache2, - w2, - intermediate_cache3, - a2q_scale, - self.w2_scale, - topk_weights, - sorted_token_ids, - expert_ids, - num_tokens_post_padded, - not apply_router_weight_on_input, - 1, - config, - compute_type=compute_type, - use_fp8_w8a8=self.quant_config.use_fp8_w8a8, - use_int8_w8a8=self.quant_config.use_int8_w8a8, - use_int8_w8a16=self.quant_config.use_int8_w8a16, - use_int4_w4a16=self.quant_config.use_int4_w4a16, - per_channel_quant=self.per_act_token_quant, - block_shape=self.block_shape, - B_bias=self.w2_bias, - ) - # LoRA w2: applied to intermediate_cache3 before moe_sum, using the # unquantized intermediate_cache2 as the lora_a input. Reuses the - # sorted_token_ids_lora computed above. - if lora_context is not None: - self.apply_w2_lora( - lora_context, - y=intermediate_cache3, - x=intermediate_cache2, - topk_weights=topk_weights, - sorted_token_ids_lora=sorted_token_ids_lora, - expert_ids_lora=expert_ids_lora, - num_tokens_post_padded_lora=num_tokens_post_padded_lora, - token_lora_mapping=token_lora_mapping, - num_tokens=num_tokens, - w1=w1, - w2=w2, - top_k_num=top_k_num, + # sorted_token_ids_lora computed above. Same dual-stream pattern as + # the w13 pair: base GEMM on default stream, LoRA delta on aux, + # join via .add_() into intermediate_cache3. + def _base_w2_fn(): + invoke_fused_moe_triton_kernel( + qintermediate_cache2, + w2, + intermediate_cache3, + a2q_scale, + self.w2_scale, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + not apply_router_weight_on_input, + 1, + config, + compute_type=compute_type, + use_fp8_w8a8=self.quant_config.use_fp8_w8a8, + use_int8_w8a8=self.quant_config.use_int8_w8a8, + use_int8_w8a16=self.quant_config.use_int8_w8a16, + use_int4_w4a16=self.quant_config.use_int4_w4a16, + per_channel_quant=self.per_act_token_quant, + block_shape=self.block_shape, + B_bias=self.w2_bias, ) + if lora_context is not None and lora_context.aux_stream is not None: + lora_delta_w2 = torch.zeros_like(intermediate_cache3) + + def _lora_w2_fn(): + self.apply_w2_lora( + lora_context, + y=lora_delta_w2, + x=intermediate_cache2, + topk_weights=topk_weights, + sorted_token_ids_lora=sorted_token_ids_lora, + expert_ids_lora=expert_ids_lora, + num_tokens_post_padded_lora=num_tokens_post_padded_lora, + token_lora_mapping=token_lora_mapping, + num_tokens=num_tokens, + w1=w1, + w2=w2, + top_k_num=top_k_num, + add_inputs=False, + ) + + assert lora_context.events is not None + maybe_execute_in_parallel( + _base_w2_fn, + _lora_w2_fn, + lora_context.events[2], + lora_context.events[3], + lora_context.aux_stream, + ) + intermediate_cache3.add_(lora_delta_w2) + else: + _base_w2_fn() + if lora_context is not None: + self.apply_w2_lora( + lora_context, + y=intermediate_cache3, + x=intermediate_cache2, + topk_weights=topk_weights, + sorted_token_ids_lora=sorted_token_ids_lora, + expert_ids_lora=expert_ids_lora, + num_tokens_post_padded_lora=num_tokens_post_padded_lora, + token_lora_mapping=token_lora_mapping, + num_tokens=num_tokens, + w1=w1, + w2=w2, + top_k_num=top_k_num, + ) + # separate function is required for MoE + LoRA self.moe_sum(intermediate_cache3, output) diff --git a/vllm/model_executor/layers/fused_moe/runner/moe_runner.py b/vllm/model_executor/layers/fused_moe/runner/moe_runner.py index 39b40d1abe6a..abd974a7c0b2 100644 --- a/vllm/model_executor/layers/fused_moe/runner/moe_runner.py +++ b/vllm/model_executor/layers/fused_moe/runner/moe_runner.py @@ -164,6 +164,8 @@ def _moe_forward_shared_fake( return shared_out, fused_out +# NOTE: `moe_forward` and `moe_forward_shared` being opaque custom ops is a +# load-bearing assumption for the MoE-LoRA dual-stream path. direct_register_custom_op( op_name="moe_forward", op_func=_moe_forward,