diff --git a/vllm/lora/ops/triton_ops/kernel_utils.py b/vllm/lora/ops/triton_ops/kernel_utils.py
index ebfffc17ae87..c6c2a02fdeb5 100644
--- a/vllm/lora/ops/triton_ops/kernel_utils.py
+++ b/vllm/lora/ops/triton_ops/kernel_utils.py
@@ -23,6 +23,7 @@ def mm_k(
     CAST_TYPE: tl.constexpr,
     b_dtype: tl.constexpr,
     USE_GDC: tl.constexpr,
+    base_k,
 ):
     """
     Given a_ptr and b_ptr, that identify the rows of A (m x k) and columns of
@@ -47,32 +48,62 @@ def mm_k(
             matrix dtype.
         b_dtype: datatype of the B matrix
         USE_GDC: Whether to use PDL. True indicates use.
+        base_k: Base offset along the K dimension for the current SPLIT_K group.
     """
     accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
-    for k in range(tl.cdiv(K, BLOCK_K * SPLIT_K)):
+
+    # Step size along K for each iteration
+    STEP_K = BLOCK_K * SPLIT_K
+
+    # Total number of iterations (compile-time constant)
+    num_iters = tl.cdiv(K, STEP_K)
+
+    for k in range(num_iters):
+        # Current iteration's global K offset
+        iter_k = k * STEP_K + base_k
+
+        # End of this iteration's block, used to decide whether masking is needed
+        block_end = iter_k + BLOCK_K
+
         if EVEN_K:
-            # pre-fetech lora weight
+            # K is divisible by BLOCK_K * SPLIT_K, so no masking is ever needed.
+            # Pre-fetch the LoRA weight.
             tiled_b = tl.load(b_ptr)
             if USE_GDC:
                 tl.extra.cuda.gdc_wait()
             tiled_a = tl.load(a_ptr)
+            if CAST_TYPE:
+                tiled_a = tiled_a.to(b_dtype)
+            accumulator += tl.dot(tiled_a, tiled_b)
         else:
-            tiled_b = tl.load(
-                b_ptr, mask=offset_k[:, None] < K - k * (BLOCK_K * SPLIT_K), other=0
-            )
-            if USE_GDC:
-                tl.extra.cuda.gdc_wait()
-            tiled_a = tl.load(
-                a_ptr, mask=offset_k[None, :] < K - k * (BLOCK_K * SPLIT_K), other=0
-            )
-        if CAST_TYPE:
-            tiled_a = tiled_a.to(b_dtype)
-        accumulator += tl.dot(
-            tiled_a,
-            tiled_b,
-        )
-        a_ptr += BLOCK_K * SPLIT_K * ak_stride
-        b_ptr += BLOCK_K * SPLIT_K * bk_stride
+            # Decide whether element-wise masking is needed.
+            if iter_k >= K:
+                # Entire block out of range; skip.
+                pass
+            elif block_end <= K:
+                # Entire block in range; no masking needed (fast path).
+                tiled_b = tl.load(b_ptr)
+                if USE_GDC:
+                    tl.extra.cuda.gdc_wait()
+                tiled_a = tl.load(a_ptr)
+                if CAST_TYPE:
+                    tiled_a = tiled_a.to(b_dtype)
+                accumulator += tl.dot(tiled_a, tiled_b)
+            else:
+                # Partial block; masking is needed (only on the last iteration).
+                k_offsets = tl.arange(0, BLOCK_K)
+                mask = iter_k + k_offsets < K
+                tiled_b = tl.load(b_ptr, mask=mask[:, None], other=0.0)
+                if USE_GDC:
+                    tl.extra.cuda.gdc_wait()
+                tiled_a = tl.load(a_ptr, mask=mask[None, :], other=0.0)
+                if CAST_TYPE:
+                    tiled_a = tiled_a.to(b_dtype)
+                accumulator += tl.dot(tiled_a, tiled_b)
+
+        a_ptr += STEP_K * ak_stride
+        b_ptr += STEP_K * bk_stride
+
     return accumulator
@@ -178,6 +209,7 @@ def do_expand_kernel(
         CAST_TYPE,
         cur_lora_ptr.dtype.element_ty,
         USE_GDC,
+        base_k=0,
     )
     tiled_c = accumulator.to(cur_lora_ptr.dtype.element_ty)
@@ -284,6 +316,7 @@ def do_shrink_kernel(
         False,
         cur_lora_ptr.dtype.element_ty,
         False,  # USE_GDC is always False in shrink kernel
+        base_k=pid_sk * BLOCK_K,
     )
     # GDC launch dependents hints the runtime system to launch dependent kernels.
     if USE_GDC:
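
As a sanity check on the new indexing, the following is a minimal host-side NumPy sketch, not part of the patch and not Triton code. The helper name emulate_mm_k is hypothetical; it mirrors STEP_K, iter_k, the skip/fast-path/masked-path selection above, and the base_k = pid_sk * BLOCK_K offset that do_shrink_kernel now passes in. Summing the partial accumulators over all pid_sk should reproduce the full matmul, including a K that is not a multiple of BLOCK_K.

import numpy as np

def emulate_mm_k(A, B, BLOCK_K, SPLIT_K, pid_sk):
    """Hypothetical emulation of the patched mm_k loop for one split-K
    program id (pid_sk). A is (M, K), B is (K, N); returns this
    program's partial accumulator."""
    M, K = A.shape
    acc = np.zeros((M, B.shape[1]))
    STEP_K = BLOCK_K * SPLIT_K          # step along K per loop iteration
    base_k = pid_sk * BLOCK_K           # as passed by do_shrink_kernel
    num_iters = -(-K // STEP_K)         # ceiling division, like tl.cdiv(K, STEP_K)
    for k in range(num_iters):
        iter_k = k * STEP_K + base_k    # this iteration's global K offset
        if iter_k >= K:
            continue                    # whole block out of range: skipped
        hi = min(iter_k + BLOCK_K, K)   # masked tail only on a partial block
        acc += A[:, iter_k:hi] @ B[iter_k:hi, :]
    return acc

# Partial results from all SPLIT_K programs must sum to the full product,
# including a ragged tail (K = 37 is not a multiple of BLOCK_K = 16).
rng = np.random.default_rng(0)
A = rng.standard_normal((8, 37))
B = rng.standard_normal((37, 5))
SPLIT_K, BLOCK_K = 4, 16
total = sum(emulate_mm_k(A, B, BLOCK_K, SPLIT_K, s) for s in range(SPLIT_K))
assert np.allclose(total, A @ B)

With SPLIT_K = 4 and BLOCK_K = 16 over K = 37, programs 0-2 cover K ranges [0, 16), [16, 32), and the masked tail [32, 37), while program 3 starts at iter_k = 48 >= K and takes the skip path, matching the per-iteration dispatch the patch introduces.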