diff --git a/vllm/lora/ops/triton_ops/fused_moe_lora_op.py b/vllm/lora/ops/triton_ops/fused_moe_lora_op.py
index 35939b9791fe..83582e4b99e1 100644
--- a/vllm/lora/ops/triton_ops/fused_moe_lora_op.py
+++ b/vllm/lora/ops/triton_ops/fused_moe_lora_op.py
@@ -62,6 +62,7 @@ def _fused_moe_lora_kernel(
     num_experts,
     lora_ids,
     adapter_enabled,
+    max_loras,  # <<< PR2: rename, used for masks when grid axis-2 != max_loras
     # The stride variables represent how much to increase the ptr by when
     # moving by 1 element in a particular dimension. E.g. `stride_am` is
     # how much to increase `a_ptr` by to get the element one row down
@@ -83,6 +84,7 @@ def _fused_moe_lora_kernel(
     num_slice_c: tl.constexpr,
     top_k: tl.constexpr,
     MUL_ROUTED_WEIGHT: tl.constexpr,
+    USE_B_L2_CACHE: tl.constexpr,  # new, enable .ca load for B
     BLOCK_SIZE_M: tl.constexpr,
     BLOCK_SIZE_N: tl.constexpr,
     BLOCK_SIZE_K: tl.constexpr,
@@ -104,10 +106,13 @@ def _fused_moe_lora_kernel(
     if moe_enabled == 0:
         # Early exit for the no moe lora case.
         return
-    # The grid size on axis 2 is (max_loras + 1) to handle the no-lora case
-    # (lora_id == -1), but sorted_token_ids and expert_ids are allocated with
-    # shape (max_loras, ...). Use (num_programs - 1) for correct bounds checking.
-    max_loras = tl.num_programs(axis=2) - 1
+    # The grid's axis-2 dimension is max_loras + 1 to accommodate the -1 sentinel.
+    # This guard ensures we don't access sorted_token_ids / expert_ids /
+    # num_tokens_post_padded beyond their allocated bounds if an invalid
+    # lora_id somehow appears. Although the caller should pass correct
+    # max_loras, defensive programming prevents accidental out-of-bounds.
+    if lora_id >= max_loras:
+        return
     grid_k = tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)

     # calculate pid_m,pid_n
@@ -136,10 +141,11 @@ def _fused_moe_lora_kernel(
     cur_b_ptr = tl.load(b_ptr + slice_id).to(tl.pointer_type(c_ptr.dtype.element_ty))
     cur_c_ptr = c_ptr + (slice_id % num_slice_c) * slice_c_size

-    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
+    # remove modulo wrap-around
+    offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int32)
     offs_k = pid_sk * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
-    offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
+    offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int32)
     token_ind = stride_tl * lora_id + offs_token_id
     offs_token = tl.load(
         sorted_token_ids_ptr + token_ind,
@@ -176,7 +182,13 @@ def _fused_moe_lora_kernel(
         # GDC wait waits for ALL programs in the prior kernel to complete
        # before continuing.
         # pre-fetch lora weight
-        b = tl.load(b_ptrs, mask=offs_k[:, None] < k_remaining, other=0.0)
+        # add (offs_bn < N) mask; optional .ca for B
+        b_mask = (offs_k[:, None] < k_remaining) & (offs_bn[None, :] < N)
+        if USE_B_L2_CACHE:
+            b = tl.load(b_ptrs, mask=b_mask, other=0.0, cache_modifier=".ca")
+        else:
+            b = tl.load(b_ptrs, mask=b_mask, other=0.0)
+
         if USE_GDC and not IS_PRIMARY:
             tl.extra.cuda.gdc_wait()
         a = tl.load(
@@ -276,6 +288,7 @@ def _fused_moe_lora_shrink(
         num_experts,
         lora_ids,
         adapter_enabled,
+        lora_a_stacked[0].shape[0],
         qcurr_hidden_states.stride(0),
         qcurr_hidden_states.stride(1),
         w1_lora_a_stacked.stride(0),
@@ -292,6 +305,7 @@ def _fused_moe_lora_shrink(
         num_slice_c=num_slices,
         top_k=1 if mul_routed_weight else top_k_num,
         MUL_ROUTED_WEIGHT=False,
+        USE_B_L2_CACHE=True,  # new
         IS_PRIMARY=True,
         **shrink_config,
     )
@@ -377,6 +391,7 @@ def _fused_moe_lora_expand(
         num_experts,
         lora_ids,
         adapter_enabled,
+        lora_b_stacked[0].shape[0],
         a_intermediate_cache1.stride(0),
         a_intermediate_cache1.stride(1),
         w1_lora_b_stacked.stride(0),
@@ -393,6 +408,7 @@ def _fused_moe_lora_expand(
         num_slice_c=num_slices,
         top_k=1,
         MUL_ROUTED_WEIGHT=mul_routed_weight,
+        USE_B_L2_CACHE=True,  # new
         IS_PRIMARY=False,
         **expand_config,
    )
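
Note: below is a minimal, self-contained sketch of the constexpr-gated cache-modifier load pattern this patch introduces, not code from the PR itself. The kernel name (vec_copy_kernel) and the flag/parameter names (USE_L2_CACHE, BLOCK) are illustrative assumptions; only tl.load's cache_modifier argument and the compile-time specialization of tl.constexpr are standard Triton behavior.

import torch
import triton
import triton.language as tl


@triton.jit
def vec_copy_kernel(x_ptr, y_ptr, n, USE_L2_CACHE: tl.constexpr, BLOCK: tl.constexpr):
    # Bounds mask instead of modulo wrap-around, analogous to (offs_bn < N) above.
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    if USE_L2_CACHE:
        # ".ca" hints that the load should be cached at all levels (L1 and L2).
        x = tl.load(x_ptr + offs, mask=mask, other=0.0, cache_modifier=".ca")
    else:
        x = tl.load(x_ptr + offs, mask=mask, other=0.0)
    tl.store(y_ptr + offs, x, mask=mask)


x = torch.randn(1000, device="cuda")
y = torch.empty_like(x)
grid = (triton.cdiv(x.numel(), 128),)
vec_copy_kernel[grid](x, y, x.numel(), USE_L2_CACHE=True, BLOCK=128)

Because the flag is a tl.constexpr, the branch is resolved at compile time and each specialization contains only one load path, so the non-cached variant pays no runtime cost; the same applies to USE_B_L2_CACHE in _fused_moe_lora_kernel above.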