Skip to content

Commit 21cd263

Browse files
committed
[Kernel] Optimization of the mm_k operator.
1 parent 11fd69d commit 21cd263

File tree

1 file changed

+50
-17
lines changed

1 file changed

+50
-17
lines changed

vllm/lora/ops/triton_ops/kernel_utils.py

Lines changed: 50 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ def mm_k(
2222
SPLIT_K: tl.constexpr,
2323
CAST_TYPE: tl.constexpr,
2424
b_dtype: tl.constexpr,
25+
base_k,
2526
):
2627
"""
2728
Given a_ptr and b_ptr, that identify the rows of A (m x k) and columns of
@@ -45,27 +46,57 @@ def mm_k(
4546
CAST_TYPE: if True, cast the values from the A matrix to the B
4647
matrix dtype.
4748
b_dtype: datatype of the B matrix
49+
base_k: Base offset along K dimension for current SPLIT_K group
4850
"""
4951
accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
50-
for k in range(tl.cdiv(K, BLOCK_K * SPLIT_K)):
52+
53+
# Step size along K for each iteration
54+
STEP_K = BLOCK_K * SPLIT_K
55+
56+
# Total number of iterations (compile-time constant)
57+
num_iters = tl.cdiv(K, STEP_K)
58+
59+
for k in range(num_iters):
60+
# Current iteration's global K offset
61+
iter_k = k * STEP_K + base_k
62+
63+
# Exclusive end of this iteration's K block; block_end <= K means the whole block is in range (no masking needed)
64+
block_end = iter_k + BLOCK_K
65+
5166
if EVEN_K:
52-
tiled_a = tl.load(a_ptr)
53-
tiled_b = tl.load(b_ptr)
67+
# K is divisible by BLOCK_K, no masking ever needed
68+
# But skip if entire block is out of range
69+
if iter_k < K:
70+
tiled_a = tl.load(a_ptr)
71+
tiled_b = tl.load(b_ptr)
72+
if CAST_TYPE:
73+
tiled_a = tiled_a.to(b_dtype)
74+
accumulator += tl.dot(tiled_a, tiled_b)
5475
else:
55-
tiled_a = tl.load(
56-
a_ptr, mask=offset_k[None, :] < K - k * (BLOCK_K * SPLIT_K), other=0
57-
)
58-
tiled_b = tl.load(
59-
b_ptr, mask=offset_k[:, None] < K - k * (BLOCK_K * SPLIT_K), other=0
60-
)
61-
if CAST_TYPE:
62-
tiled_a = tiled_a.to(b_dtype)
63-
accumulator += tl.dot(
64-
tiled_a,
65-
tiled_b,
66-
)
67-
a_ptr += BLOCK_K * SPLIT_K * ak_stride
68-
b_ptr += BLOCK_K * SPLIT_K * bk_stride
76+
# Check if we need element-wise masking
77+
if iter_k >= K:
78+
# Entire block out of range, skip
79+
pass
80+
elif block_end <= K:
81+
# Entire block in range, no masking needed (fast path)
82+
tiled_a = tl.load(a_ptr)
83+
tiled_b = tl.load(b_ptr)
84+
if CAST_TYPE:
85+
tiled_a = tiled_a.to(b_dtype)
86+
accumulator += tl.dot(tiled_a, tiled_b)
87+
else:
88+
# Partial block, need masking (only last iteration)
89+
k_offsets = tl.arange(0, BLOCK_K)
90+
mask = iter_k + k_offsets < K
91+
tiled_a = tl.load(a_ptr, mask=mask[None, :], other=0.0)
92+
tiled_b = tl.load(b_ptr, mask=mask[:, None], other=0.0)
93+
if CAST_TYPE:
94+
tiled_a = tiled_a.to(b_dtype)
95+
accumulator += tl.dot(tiled_a, tiled_b)
96+
97+
a_ptr += STEP_K * ak_stride
98+
b_ptr += STEP_K * bk_stride
99+
69100
return accumulator
70101

71102

@@ -168,6 +199,7 @@ def do_expand_kernel(
168199
SPLIT_K,
169200
CAST_TYPE,
170201
cur_lora_ptr.dtype.element_ty,
202+
base_k=0,
171203
)
172204

173205
tiled_c = accumulator.to(cur_lora_ptr.dtype.element_ty)
@@ -272,6 +304,7 @@ def do_shrink_kernel(
272304
SPLIT_K,
273305
False,
274306
cur_lora_ptr.dtype.element_ty,
307+
base_k=pid_sk * BLOCK_K,
275308
)
276309

277310
# Identify the C output pointers to store the results of the accumulator.

0 commit comments

Comments
 (0)