diff --git a/python/sglang/srt/lora/triton_ops/__init__.py b/python/sglang/srt/lora/triton_ops/__init__.py index bf7745a284c4..7f96837cd367 100644 --- a/python/sglang/srt/lora/triton_ops/__init__.py +++ b/python/sglang/srt/lora/triton_ops/__init__.py @@ -2,6 +2,7 @@ from .chunked_sgmv_expand import chunked_sgmv_lora_expand_forward from .chunked_sgmv_shrink import chunked_sgmv_lora_shrink_forward from .embedding_lora_a import embedding_lora_a_fwd +from .fused_moe_lora_kernel import fused_moe_lora from .gate_up_lora_b import gate_up_lora_b_fwd from .qkv_lora_b import qkv_lora_b_fwd from .sgemm_lora_a import sgemm_lora_a_fwd @@ -14,6 +15,7 @@ "sgemm_lora_b_fwd", "chunked_sgmv_lora_shrink_forward", "chunked_sgmv_lora_expand_forward", + "fused_moe_lora", "chunked_embedding_lora_a_forward", "embedding_lora_a_fwd", ] diff --git a/python/sglang/srt/lora/triton_ops/fused_moe_lora_kernel.py b/python/sglang/srt/lora/triton_ops/fused_moe_lora_kernel.py new file mode 100644 index 000000000000..0d5ae426d31f --- /dev/null +++ b/python/sglang/srt/lora/triton_ops/fused_moe_lora_kernel.py @@ -0,0 +1,690 @@ +# Temporarily adapted from https://github.com/vllm-project/vllm/blob/main/vllm/lora/ops/triton_ops/fused_moe_lora_op.py, will optimize in future refactor + +import torch +import triton +import triton.language as tl + +from sglang.srt.distributed import ( + tensor_model_parallel_all_gather, + tensor_model_parallel_all_reduce, +) +from sglang.srt.utils.common import is_blackwell_supported, is_sm90_supported + +# Import SGLang's standard PDL support detection + + +_LORA_PTR_DICT: dict[tuple[int, ...], torch.Tensor] = {} + + +def _get_ptr(lora_weights: list[torch.Tensor], device: torch.device): + """ + `_LORA_PTR_DICT` collects the required information during `profile_run`, + After this, it remains constant and subsequent usage is through LUT. + Refer to: + https://github.com/triton-lang/triton/blob/release/3.1.x/python/tutorials/08-grouped-gemm.py + """ + key = tuple(lora_weight.data_ptr() for lora_weight in lora_weights) + + if (ptr_tensor := _LORA_PTR_DICT.get(key)) is not None: + return ptr_tensor + + tensor_ptrs = [] + for lora_weight in lora_weights: + tensor_ptrs.append(lora_weight.data_ptr()) + ptr_tensor = torch.tensor(tensor_ptrs, device=device, dtype=torch.uint64) + + _LORA_PTR_DICT[key] = ptr_tensor + return _LORA_PTR_DICT.get(key) + + +@triton.jit( + do_not_specialize=[ + "num_valid_tokens", + "EM", + "stride_tl", + "stride_el", + "slice_a_size", + "slice_c_size", + ] +) +def _fused_moe_lora_kernel( + a_ptr, + b_ptr, + c_ptr, + topk_weights_ptr, + sorted_token_ids_ptr, + expert_ids_ptr, + num_tokens_post_padded_ptr, + # Matrix dimensions + N, + K, + EM, + num_valid_tokens, + num_experts, + lora_ids, + adapter_enabled, + # The stride variables represent how much to increase the ptr by when + # moving by 1 element in a particular dimension. E.g. `stride_am` is + # how much to increase `a_ptr` by to get the element one row down + # (A has M rows). 
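+    # NOTE: callers pass B strides for LoRA weights stored as
+    # (max_loras, num_experts, N, K), so `stride_bk`/`stride_bn` receive
+    # `.stride(3)`/`.stride(2)` of the stacked weight tensor.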
+ stride_am, + stride_ak, + stride_bl, + stride_be, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_tl, + stride_el, + slice_a_size, + slice_c_size, + # Meta-parameters + num_slice_a: tl.constexpr, + num_slice_c: tl.constexpr, + top_k: tl.constexpr, + MUL_ROUTED_WEIGHT: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + SPLIT_K: tl.constexpr, + USE_GDC: tl.constexpr, + launch_pdl: tl.constexpr, + IS_PRIMARY: tl.constexpr, +): + pid = tl.program_id(axis=0) + slice_id = tl.program_id(axis=1) + lora_idx = tl.program_id(axis=2) + lora_id = tl.load(lora_ids + lora_idx) + + if lora_id == -1: + # Early exit for the no-lora case. + return + moe_enabled = tl.load(adapter_enabled + lora_id) + if moe_enabled == 0: + # Early exit for the no moe lora case. + return + max_loras = tl.num_programs(axis=2) + grid_k = tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K) + + # calculate pid_m,pid_n + pid_sk = pid % SPLIT_K + pid_m_n = pid // SPLIT_K + num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid_m_n // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid_m_n % num_pid_in_group) % group_size_m) + pid_n = (pid_m_n % num_pid_in_group) // group_size_m + + num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr + lora_id) + if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded: + return + # get the expert_id to process curr shard + ind = lora_id * stride_el + pid_m + expert_id = tl.load(expert_ids_ptr + ind, ind < max_loras * stride_el, -1) + if expert_id == -1: + return + + # get a_ptr,b_ptr,c_ptr + cur_a_ptr = a_ptr + (slice_id % num_slice_a) * slice_a_size + cur_b_ptr = tl.load(b_ptr + slice_id).to(tl.pointer_type(c_ptr.dtype.element_ty)) + cur_c_ptr = c_ptr + (slice_id % num_slice_c) * slice_c_size + + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N + offs_k = pid_sk * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + # ================================================================= secure + + offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64) + token_ind = stride_tl * lora_id + offs_token_id + offs_token = tl.load( + sorted_token_ids_ptr + token_ind, token_ind < max_loras * stride_tl, 0 + ) + token_mask = offs_token < num_valid_tokens + + # ================================================================= secure + + # get a_ptrs,b_ptrs + a_ptrs = cur_a_ptr + ( + offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak + ) + + b_ptrs = ( + cur_b_ptr + + lora_id * stride_bl + + expert_id * stride_be + + offs_k[:, None] * stride_bk + + offs_bn[None, :] * stride_bn + ) + + if USE_GDC and IS_PRIMARY: + # GDC launch dependents hints the runtime system to launch dependent kernels. + tl.extra.cuda.gdc_launch_dependents() + + # ================================================================= secure + + # accumulator + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # ================================================================= secure + + # GDC wait waits for ALL programs in the prior kernel to complete + # before continuing. 
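+    # Under PDL, the primary (shrink) pass has already issued
+    # gdc_launch_dependents(), so this dependent (expand) pass may start
+    # early; waiting here guarantees the shrink output we read as A is
+    # fully written.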
+ if USE_GDC and not IS_PRIMARY: + tl.extra.cuda.gdc_wait() + + for k in range(0, grid_k): + k_remaining = K - k * (BLOCK_SIZE_K * SPLIT_K) + # pre-fetch lora weight + b = tl.load(b_ptrs, mask=offs_k[:, None] < k_remaining, other=0.0) + a = tl.load( + a_ptrs, + mask=token_mask[:, None] & (offs_k[None, :] < k_remaining), + other=0.0, + ) + accumulator += tl.dot(a, b) + # Advance the ptrs to the next K block. + a_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_ak + b_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_bk + + if MUL_ROUTED_WEIGHT: + moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0) + accumulator = accumulator * moe_weight[:, None] + accumulator = accumulator.to(c_ptr.dtype.element_ty) + # Write back the block of the output + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = cur_c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :] + c_mask = token_mask[:, None] & (offs_cn[None, :] < N) + + if SPLIT_K == 1: + tl.store(c_ptrs, accumulator, mask=c_mask) + else: + tl.atomic_add(c_ptrs, accumulator, mask=c_mask, sem="relaxed") + + +@torch.inference_mode() +def _fused_moe_lora_shrink( + a_intermediate_cache1: torch.Tensor, + # (num_slices, num_tokens, top_k_num, max_lora_rank) + qcurr_hidden_states: torch.Tensor, # (num_tokens, K,) + lora_a_stacked: list[ + torch.Tensor + ], # [(max_loras, num_experts, max_lora_rank, K,),...] + topk_weights: torch.Tensor, # (num_tokens, top_k_num) + sorted_token_ids: torch.Tensor, # (max_loras, _) + expert_ids: torch.Tensor, # (max_loras, _ ,) + num_tokens_post_padded: torch.Tensor, # (max_loras, ) + top_k_num: int, + lora_ids: torch.Tensor, + adapter_enabled: torch.Tensor, + ## adding for kernel + device: torch.device, + N: int, + M: int, + EM: int, + K: int, + num_tokens: int, + num_experts: int, + num_slices: int, + block_size_m: int, + block_size_n: int, + block_size_k: int, + group_size_m: int, + num_warps: int, + num_stages: int, + split_k: int, + mul_routed_weight: bool = False, +) -> None: + w1_lora_a_stacked = lora_a_stacked[0] + + use_gdc = is_sm90_supported() or is_blackwell_supported() + shrink_config = { + "BLOCK_SIZE_M": block_size_m, + "BLOCK_SIZE_N": block_size_n, + "BLOCK_SIZE_K": block_size_k, + "GROUP_SIZE_M": group_size_m, + "num_warps": num_warps, + "num_stages": num_stages, + "SPLIT_K": split_k, + "USE_GDC": use_gdc, + "launch_pdl": use_gdc, # triton kernel metadata + } + + b_ptr = _get_ptr(lora_a_stacked, device) + + grid = lambda META: ( + split_k + * triton.cdiv(EM, META["BLOCK_SIZE_M"]) + * triton.cdiv(N, META["BLOCK_SIZE_N"]), + len(lora_a_stacked), + lora_a_stacked[0].shape[0], + ) + _fused_moe_lora_kernel[grid]( + qcurr_hidden_states, + b_ptr, + a_intermediate_cache1, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + N, + K, + EM, + num_tokens, + num_experts, + lora_ids, + adapter_enabled, + qcurr_hidden_states.stride(0), + qcurr_hidden_states.stride(1), + w1_lora_a_stacked.stride(0), + w1_lora_a_stacked.stride(1), + w1_lora_a_stacked.stride(3), + w1_lora_a_stacked.stride(2), + a_intermediate_cache1.stride(2), + a_intermediate_cache1.stride(3), + sorted_token_ids.stride(0), + expert_ids.stride(0), + slice_a_size=qcurr_hidden_states.numel(), + slice_c_size=a_intermediate_cache1.numel() // num_slices, + num_slice_a=1, + num_slice_c=num_slices, + top_k=1 if mul_routed_weight else top_k_num, + MUL_ROUTED_WEIGHT=False, + IS_PRIMARY=True, + **shrink_config, + ) + + +@torch.inference_mode() +def _fused_moe_lora_expand( + output: torch.Tensor, # 
(num_tokens, top_k_num, N*len(lora_a_stacked),) + a_intermediate_cache1: torch.Tensor, # (num_slices, M, top_k_num, max_lora_rank) + b_intermediate_cache1: torch.Tensor, # (num_slices, M, top_k_num, output_dim_size) + lora_b_stacked: list[ + torch.Tensor + ], # [(max_loras, num_experts, max_lora_rank, K,),...] + topk_weights: torch.Tensor, # (num_tokens, top_k_num) + sorted_token_ids: torch.Tensor, # (max_loras, _) + expert_ids: torch.Tensor, # (max_loras, _ ,) + num_tokens_post_padded: torch.Tensor, # (max_loras, ) + top_k_num: int, + lora_ids: torch.Tensor, + adapter_enabled: torch.Tensor, + ## adding for kernel + device: torch.device, + N: int, + M: int, + EM: int, + K: int, + num_tokens: int, + num_experts: int, + num_slices: int, + max_lora_rank: int, + w1_output_dim_size: int, + block_size_m: int, + block_size_n: int, + block_size_k: int, + group_size_m: int, + num_warps: int, + num_stages: int, + split_k: int, + mul_routed_weight: bool = False, + offset: int = 0, +) -> None: + + b_ptr = _get_ptr(lora_b_stacked, device) + K = max_lora_rank + N = w1_output_dim_size + + w1_lora_b_stacked = lora_b_stacked[0] + + a_intermediate_cache1 = a_intermediate_cache1.view( + -1, a_intermediate_cache1.shape[3] + ) + + use_gdc = is_sm90_supported() or is_blackwell_supported() + expand_config = { + "BLOCK_SIZE_M": block_size_m, + "BLOCK_SIZE_N": block_size_n, + "BLOCK_SIZE_K": block_size_k, + "GROUP_SIZE_M": group_size_m, + "num_warps": num_warps, + "num_stages": num_stages, + "SPLIT_K": split_k, # Set split_k = 1 for expand calls + "USE_GDC": use_gdc, + "launch_pdl": use_gdc, # triton kernel metadata + } + + grid = lambda META: ( + triton.cdiv(EM, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), + len(lora_b_stacked), + lora_b_stacked[0].shape[0], + ) + _fused_moe_lora_kernel[grid]( + a_intermediate_cache1, + b_ptr, + b_intermediate_cache1, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + N, + K, + EM, + num_tokens, + num_experts, + lora_ids, + adapter_enabled, + a_intermediate_cache1.stride(0), + a_intermediate_cache1.stride(1), + w1_lora_b_stacked.stride(0), + w1_lora_b_stacked.stride(1), + w1_lora_b_stacked.stride(3), + w1_lora_b_stacked.stride(2), + b_intermediate_cache1.stride(2), + b_intermediate_cache1.stride(3), + sorted_token_ids.stride(0), + expert_ids.stride(0), + slice_a_size=a_intermediate_cache1.numel() // num_slices, + slice_c_size=b_intermediate_cache1.numel() // num_slices, + num_slice_a=num_slices, + num_slice_c=num_slices, + top_k=1, + MUL_ROUTED_WEIGHT=mul_routed_weight, + IS_PRIMARY=False, + **expand_config, + ) + for i in range(num_slices): + output[:, :, i * N + offset : (i + 1) * N + offset] += b_intermediate_cache1[i] + + +@torch.inference_mode() +def _fused_moe_lora( + output: torch.Tensor, # (num_tokens, top_k_num, N*len(lora_a_stacked),) + qcurr_hidden_states: torch.Tensor, # (num_tokens, K,) + lora_a_stacked: list[ + torch.Tensor + ], # [(max_loras, num_experts, max_lora_rank, K,),...] + lora_b_stacked: list[ + torch.Tensor + ], # [(max_loras, num_experts, N, max_lora_rank,),...] 
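+    # One entry per output slice (e.g. gate and up projections); every slice
+    # shares the same token/expert routing, and slice i is written to
+    # output[..., i * N + offset : (i + 1) * N + offset] by the expand pass.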
+ topk_weights: torch.Tensor, # (num_tokens, top_k_num) + sorted_token_ids: torch.Tensor, # (max_loras, _) + expert_ids: torch.Tensor, # (max_loras, _ ,) + num_tokens_post_padded: torch.Tensor, # (max_loras, ) + max_lora_rank: int, + top_k_num: int, + lora_ids: torch.Tensor, + adapter_enabled: torch.Tensor, + shrink_block_size_m: int, + shrink_block_size_n: int, + shrink_block_size_k: int, + shrink_group_size_m: int, + shrink_num_warps: int, + shrink_num_stages: int, + shrink_split_k: int, + expand_block_size_m: int, + expand_block_size_n: int, + expand_block_size_k: int, + expand_group_size_m: int, + expand_num_warps: int, + expand_num_stages: int, + expand_split_k: int, + mul_routed_weight: bool = False, + fully_sharded: bool = False, + offset: int = 0, +) -> None: + assert len(lora_a_stacked) == len(lora_b_stacked) > 0 + assert ( + sorted_token_ids.dim() + == expert_ids.dim() + == topk_weights.dim() + == qcurr_hidden_states.dim() + == 2 + ) + assert ( + sorted_token_ids.shape[0] + == expert_ids.shape[0] + == num_tokens_post_padded.shape[0] + ) + assert output.shape[0] == topk_weights.shape[0] + assert top_k_num == topk_weights.shape[1] + device = qcurr_hidden_states.device + num_slices = len(lora_a_stacked) + w1_lora_b_stacked = lora_b_stacked[0] + num_experts = lora_a_stacked[0].shape[1] + N = max_lora_rank + M = topk_weights.shape[0] + EM = sorted_token_ids.shape[1] + K = qcurr_hidden_states.shape[1] + num_tokens = M * top_k_num + w1_output_dim_size = w1_lora_b_stacked.shape[2] + + a_intermediate_cache1 = torch.zeros( + (num_slices, M, top_k_num, max_lora_rank), + dtype=output.dtype, + device=device, + ) + + b_intermediate_cache1 = torch.zeros( + (num_slices, M, top_k_num, w1_output_dim_size), + dtype=output.dtype, + device=device, + ) + + _fused_moe_lora_shrink( + a_intermediate_cache1, + qcurr_hidden_states, + lora_a_stacked, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + top_k_num, + lora_ids, + adapter_enabled, + ## adding for kernel + device, + N, + M, + EM, + K, + num_tokens, + num_experts, + num_slices, + shrink_block_size_m, + shrink_block_size_n, + shrink_block_size_k, + shrink_group_size_m, + shrink_num_warps, + shrink_num_stages, + shrink_split_k, + mul_routed_weight=False, + ) + + if fully_sharded: + if max_lora_rank == w1_lora_b_stacked.shape[-1]: + a_intermediate_cache1 = tensor_model_parallel_all_reduce( + a_intermediate_cache1 + ) + else: + a_intermediate_cache1 = tensor_model_parallel_all_gather( + a_intermediate_cache1 + ) + + # reset max_lora_rank to the full rank after allgather + max_lora_rank = a_intermediate_cache1.shape[-1] + + _fused_moe_lora_expand( + output, + a_intermediate_cache1, + b_intermediate_cache1, + lora_b_stacked, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + top_k_num, + lora_ids, + adapter_enabled, + ## adding for kernel + device, + N, + M, + EM, + K, + num_tokens, + num_experts, + num_slices, + max_lora_rank, + w1_output_dim_size, + expand_block_size_m, + expand_block_size_n, + expand_block_size_k, + expand_group_size_m, + expand_num_warps, + expand_num_stages, + expand_split_k, + mul_routed_weight, + offset, + ) + + +def _fused_moe_lora_fake( + output: torch.Tensor, + qcurr_hidden_states: torch.Tensor, + lora_a_stacked: list[torch.Tensor], + lora_b_stacked: list[torch.Tensor], + topk_weights: torch.Tensor, + sorted_token_ids: torch.Tensor, + expert_ids: torch.Tensor, + num_tokens_post_padded: torch.Tensor, + max_lora_rank: int, + top_k_num: int, + lora_ids: torch.Tensor, + 
adapter_enabled: torch.Tensor, + shrink_block_size_m: int, + shrink_block_size_n: int, + shrink_block_size_k: int, + shrink_group_size_m: int, + shrink_num_warps: int, + shrink_num_stages: int, + shrink_split_k: int, + expand_block_size_m: int, + expand_block_size_n: int, + expand_block_size_k: int, + expand_group_size_m: int, + expand_num_warps: int, + expand_num_stages: int, + expand_split_k: int, + mul_routed_weight: bool = False, + fully_sharded: bool = False, + offset: int = 0, +) -> None: + return + + +def _fused_moe_lora_shrink_fake( + a_intermediate_cache1: torch.Tensor, + qcurr_hidden_states: torch.Tensor, + lora_a_stacked: list[torch.Tensor], + topk_weights: torch.Tensor, + sorted_token_ids: torch.Tensor, + expert_ids: torch.Tensor, + num_tokens_post_padded: torch.Tensor, + top_k_num: int, + lora_ids: torch.Tensor, + adapter_enabled: torch.Tensor, + device: torch.device, + N: int, + M: int, + EM: int, + K: int, + num_tokens: int, + num_experts: int, + num_slices: int, + block_size_m: int, + block_size_n: int, + block_size_k: int, + group_size_m: int, + num_warps: int, + num_stages: int, + split_k: int, + mul_routed_weight: bool = False, +) -> None: + return + + +def _fused_moe_lora_expand_fake( + output: torch.Tensor, + a_intermediate_cache1: torch.Tensor, + b_intermediate_cache1: torch.Tensor, + lora_b_stacked: list[torch.Tensor], + topk_weights: torch.Tensor, + sorted_token_ids: torch.Tensor, + expert_ids: torch.Tensor, + num_tokens_post_padded: torch.Tensor, + top_k_num: int, + lora_ids: torch.Tensor, + adapter_enabled: torch.Tensor, + device: torch.device, + N: int, + M: int, + EM: int, + K: int, + num_tokens: int, + num_experts: int, + num_slices: int, + max_lora_rank: int, + w1_output_dim_size: int, + block_size_m: int, + block_size_n: int, + block_size_k: int, + group_size_m: int, + num_warps: int, + num_stages: int, + split_k: int, + mul_routed_weight: bool = False, + offset: int = 0, +) -> None: + return + + +# Register as SGLang custom ops following the same pattern as other ops +try: + from sglang.srt.utils.common import direct_register_custom_op + + direct_register_custom_op( + op_name="fused_moe_lora", + op_func=_fused_moe_lora, + mutates_args=["output"], + fake_impl=_fused_moe_lora_fake, + ) + + direct_register_custom_op( + op_name="fused_moe_lora_shrink", + op_func=_fused_moe_lora_shrink, + mutates_args=["a_intermediate_cache1"], + fake_impl=_fused_moe_lora_shrink_fake, + ) + + direct_register_custom_op( + op_name="fused_moe_lora_expand", + op_func=_fused_moe_lora_expand, + mutates_args=["output", "b_intermediate_cache1"], + fake_impl=_fused_moe_lora_expand_fake, + ) + + # Export through torch.ops.sglang namespace + fused_moe_lora = torch.ops.sglang.fused_moe_lora + fused_moe_lora_shrink = torch.ops.sglang.fused_moe_lora_shrink + fused_moe_lora_expand = torch.ops.sglang.fused_moe_lora_expand + +except AttributeError: + fused_moe_lora = _fused_moe_lora + fused_moe_lora_shrink = _fused_moe_lora_shrink + fused_moe_lora_expand = _fused_moe_lora_expand diff --git a/test/registered/lora/test_fused_moe_lora_kernel.py b/test/registered/lora/test_fused_moe_lora_kernel.py new file mode 100644 index 000000000000..4f8664624361 --- /dev/null +++ b/test/registered/lora/test_fused_moe_lora_kernel.py @@ -0,0 +1,380 @@ +# Temporarily adapted from https://github.com/vllm-project/vllm/blob/main/tests/lora/test_fused_moe_lora_kernel.py, will optimize in future refactor +import random + +import pytest +import torch + +# 
============================================================================== +# IMPORT PREBUILT KERNEL +# ============================================================================== +from sglang.jit_kernel.moe_lora_align import moe_lora_align_block_size +from sglang.srt.lora.triton_ops import fused_moe_lora +from sglang.srt.utils import set_random_seed +from sglang.test.ci.ci_register import register_cuda_ci + +# ============================================================================== + +register_cuda_ci(est_time=120, suite="stage-b-test-large-1-gpu") + + +def round_up(x, base): + return ((x + base - 1) // base) * base + + +def CEILDIV(x, y): + return (x + y - 1) // y + + +def assign_loras_to_tokens(num_tokens: int, num_sequences: int, max_loras: int): + """ + Split `num_tokens` into `num_sequences` sequences. + Each sequence randomly selects 1 LoRA index from [0, max_loras), + and all tokens in that sequence are assigned this LoRA index. + + Args: + num_tokens (int): Total number of tokens. + num_sequences (int): Number of sequences to split the tokens into. + max_loras (int): Total number of available LoRA modules. + + Returns: + token_lora_mapping (torch.Tensor): 1D tensor of shape [num_tokens] + seg_indptr (torch.Tensor): 1D tensor of shape [num_sequences + 1] + req_to_lora (torch.Tensor): 1D tensor of shape [num_sequences] + """ + assert num_sequences > 0 and max_loras > 0 + assert num_tokens >= num_sequences, "num_tokens must be >= num_sequences" + + # Compute token distribution per sequence (distribute remainder evenly) + tokens_per_seq = num_tokens // num_sequences + remainder = num_tokens % num_sequences + + token_lora_mapping = torch.empty(num_tokens, dtype=torch.int32) + seg_indptr = [0] + req_to_lora = [] + + start = 0 + for seq_idx in range(num_sequences): + # Determine the token range for this sequence + end = start + tokens_per_seq + (1 if seq_idx < remainder else 0) + + # Randomly select one LoRA ID for this sequence + lora_id = random.randint(0, max_loras - 1) + + # Assign the same LoRA ID to all tokens in this sequence + token_lora_mapping[start:end] = lora_id + + seg_indptr.append(end) + req_to_lora.append(lora_id) + + start = end + + seg_indptr = torch.tensor(seg_indptr, dtype=torch.int32) + req_to_lora = torch.tensor(req_to_lora, dtype=torch.int32) + + return token_lora_mapping, seg_indptr, req_to_lora + + +def assign_experts_to_tokens(num_tokens: int, num_experts: int, top_k_num: int): + """ + For each token, randomly select `top_k_num` distinct experts out of `num_experts`, + and assign normalized random weights that sum to 1. + + Args: + num_tokens (int): Total number of tokens. + num_experts (int): Total number of available experts. + top_k_num (int): Number of experts to select per token. + + Returns: + expert_indices (torch.Tensor): shape [num_tokens, top_k_num], + expert index for each token. + expert_weights (torch.Tensor): shape [num_tokens, top_k_num], + normalized weights (sum = 1 per row). 
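+
+    Example:
+        With num_tokens=2, num_experts=4, top_k_num=2, one possible draw is
+        expert_indices=[[3, 0], [1, 2]] with expert_weights like
+        [[0.7, 0.3], [0.4, 0.6]] (each row sums to 1).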
+ """ + assert top_k_num <= num_experts, "top_k_num must be <= num_experts" + + # Randomly select top_k_num distinct experts for each token + expert_indices = torch.empty((num_tokens, top_k_num), dtype=torch.int32) + for i in range(num_tokens): + # Randomly choose unique expert indices + selected = torch.randperm(num_experts)[:top_k_num] + expert_indices[i] = selected + + # Generate random weights and normalize along dim=1 + expert_weights = torch.rand((num_tokens, top_k_num), dtype=torch.float32) + expert_weights = expert_weights / expert_weights.sum(dim=1, keepdim=True) + + return expert_indices, expert_weights + + +def sample_data( + num_tokens: int, + num_sequences: int, + max_loras: int, + num_experts: int, + top_k_num: int, +): + topk_ids, topk_weights = assign_experts_to_tokens( + num_tokens, num_experts, top_k_num + ) + token_lora_mapping, seg_indptr, req_to_lora = assign_loras_to_tokens( + num_tokens, num_sequences, max_loras + ) + return topk_ids, topk_weights, token_lora_mapping, seg_indptr, req_to_lora + + +def use_fused_moe_lora_kernel( + topk_ids, + topk_weights, + seg_indptr, + req_to_lora, + max_lora_rank, + top_k_num, + lora_a_stacked, + lora_b_stacked, + hidden_states, + output, + max_loras, + num_experts, + block_size, + mul_routed_weight, + fully_sharded=False, + offset=0, +): + max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1) + max_num_tokens_padded = round_up(max_num_tokens_padded, block_size) + max_num_m_blocks = CEILDIV(max_num_tokens_padded, block_size) + + # Important: Ensure output tensors are on the same device as inputs + device = topk_ids.device + + # init output tensors + sorted_token_ids = torch.empty( + (max_loras * max_num_tokens_padded,), dtype=torch.int32, device=device + ) + expert_ids = torch.empty( + (max_loras * max_num_m_blocks,), dtype=torch.int32, device=device + ) + num_tokens_post_padded = torch.empty((max_loras,), dtype=torch.int32, device=device) + adapter_enabled = torch.ones(max_loras + 1, dtype=torch.int32, device=device) + lora_ids = torch.arange(max_loras, dtype=torch.int32, device=device) + + # call kernel + moe_lora_align_block_size( + topk_ids, + seg_indptr, + req_to_lora, + num_experts, + block_size, + max_loras, + max_num_tokens_padded, + max_num_m_blocks, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + adapter_enabled, + lora_ids, + None, # maybe_expert_map + ) + + config = { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "NUM_WARPS": 4, + "NUM_STAGES": 3, + "SPLIT_K": 1, + } + + expert_ids = expert_ids.view(max_loras, -1) + sorted_token_ids = sorted_token_ids.view(max_loras, -1) + + fused_moe_lora( + output, + hidden_states, + lora_a_stacked, + lora_b_stacked, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + max_lora_rank, + top_k_num, + lora_ids, + adapter_enabled, + config["BLOCK_SIZE_M"], + config["BLOCK_SIZE_N"], + config["BLOCK_SIZE_K"], + config["GROUP_SIZE_M"], + config["NUM_WARPS"], + config["NUM_STAGES"], + config["SPLIT_K"], + config["BLOCK_SIZE_M"], + config["BLOCK_SIZE_N"], + config["BLOCK_SIZE_K"], + config["GROUP_SIZE_M"], + config["NUM_WARPS"], + config["NUM_STAGES"], + config["SPLIT_K"], + mul_routed_weight, + fully_sharded=fully_sharded, + offset=offset, + ) + + +def use_torch( + hidden_states, + token_lora_mapping, + topk_ids, + topk_weights, + lora_a_stacked, + lora_b_stacked, + top_k_num, + mul_routed_weight, +): + outputs = [] + + orig_dtype = hidden_states.dtype + for i in 
range(hidden_states.shape[0]): + lora_idx = token_lora_mapping[i] + expert_ids = topk_ids[i] + expert_weights = topk_weights[i] + + lora_a = lora_a_stacked[0][lora_idx][expert_ids] + lora_b = lora_b_stacked[0][lora_idx][expert_ids] + + h_f32 = hidden_states[i].to(torch.float32) + la_f32 = lora_a.to(torch.float32) + lb_f32 = lora_b.to(torch.float32) + + if mul_routed_weight: + tensors = [ + ((h_f32 @ la_f32[x].T @ lb_f32[x].T) * expert_weights[x]).to(orig_dtype) + for x in range(top_k_num) + ] + else: + tensors = [ + (h_f32 @ la_f32[x].T @ lb_f32[x].T).to(orig_dtype) + for x in range(top_k_num) + ] + outputs.append(torch.stack(tensors, dim=0)) + return torch.stack(outputs, dim=0) + + +DTYPES = [torch.float32, torch.float16, torch.bfloat16] +DEVICES = [f"cuda:{0}"] +SEED = [42] + + +@pytest.mark.parametrize("mul_routed_weight", [False, True]) +@pytest.mark.parametrize("num_tokens", [100]) +@pytest.mark.parametrize("top_k_num", [6, 12]) +@pytest.mark.parametrize("num_experts", [64]) +@pytest.mark.parametrize("max_loras", [4, 6, 16]) +@pytest.mark.parametrize("N", [1408]) +@pytest.mark.parametrize("K", [2048]) +@pytest.mark.parametrize("max_lora_rank", [16, 32, 64]) +@pytest.mark.parametrize("block_size", [16]) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("device", DEVICES) +@pytest.mark.parametrize("seed", SEED) +def test_fused_moe_lora_kernel( + mul_routed_weight, + num_tokens, + top_k_num, + num_experts, + max_loras, + N, + K, + max_lora_rank, + block_size, + dtype, + device, + seed, +): + torch.set_default_device(device) + set_random_seed(seed) + # the number of randomly generated sentences. + num_sequences = 10 + # generate data + topk_ids, topk_weights, token_lora_mapping, seg_indptr, req_to_lora = sample_data( + num_tokens, num_sequences, max_loras, num_experts, top_k_num + ) + + # Ensure generated data is on the correct device + topk_ids = topk_ids.to(device) + topk_weights = topk_weights.to(device) + token_lora_mapping = token_lora_mapping.to(device) + seg_indptr = seg_indptr.to(device) + req_to_lora = req_to_lora.to(device) + + # init lora weights + lora_a_stacked = [ + torch.rand( + ( + max_loras, + num_experts, + max_lora_rank, + K, + ), + dtype=dtype, + device=device, + ) + ] + lora_b_stacked = [ + torch.rand( + ( + max_loras, + num_experts, + N, + max_lora_rank, + ), + dtype=dtype, + device=device, + ) + ] + hidden_states = torch.rand( + ( + num_tokens, + K, + ), + dtype=dtype, + device=device, + ) + + # fused_moe_lora_kernel output + output = torch.zeros((num_tokens, top_k_num, N), dtype=dtype, device=device) + + use_fused_moe_lora_kernel( + topk_ids, + topk_weights, + seg_indptr, + req_to_lora, + max_lora_rank, + top_k_num, + lora_a_stacked, + lora_b_stacked, + hidden_states, + output, + max_loras, + num_experts, + block_size, + mul_routed_weight=mul_routed_weight, + ) + # pytorch output + output2 = use_torch( + hidden_states, + token_lora_mapping, + topk_ids, + topk_weights, + lora_a_stacked, + lora_b_stacked, + top_k_num, + mul_routed_weight=mul_routed_weight, + ) + + torch.testing.assert_close(output, output2, atol=1e-2, rtol=1e-2) + + +if __name__ == "__main__": + pytest.main([__file__])
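
Reviewer note: the semantics the fused kernel must reproduce are pinned down
by `use_torch` above. For a single token routed to one expert of one adapter,
they reduce to the eager sketch below; `h`, `A`, `B`, and `w` are hypothetical
names for one token's slices of the test tensors, not identifiers from this
diff.

    import torch

    # Assumed single-token shapes, matching the test's stacked layouts:
    #   A: (max_lora_rank, K)  = lora_a_stacked[0][lora_idx][expert_id]
    #   B: (N, max_lora_rank)  = lora_b_stacked[0][lora_idx][expert_id]
    #   h: (K,)                = one token's hidden state
    #   w: scalar              = that token's routing weight for this expert
    def one_token_one_expert(h, A, B, w, mul_routed_weight):
        # shrink: project the hidden state down to the LoRA rank (fp32 accum)
        shrunk = h.to(torch.float32) @ A.to(torch.float32).T  # (max_lora_rank,)
        # expand: project back up to the slice's output dimension
        out = shrunk @ B.to(torch.float32).T                  # (N,)
        # optionally scale by the normalized top-k routing weight
        return (out * w if mul_routed_weight else out).to(h.dtype)

The kernel fuses this over all tokens, top-k experts, adapters, and slices in
two launches (shrink, then expand), with optional SPLIT_K atomics in the
shrink pass and PDL overlapping the two launches on SM90+.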