@@ -22,6 +22,7 @@ def mm_k(
     SPLIT_K: tl.constexpr,
     CAST_TYPE: tl.constexpr,
     b_dtype: tl.constexpr,
+    base_k,
 ):
     """
     Given a_ptr and b_ptr, that identify the rows of A (m x k) and columns of
@@ -45,27 +46,57 @@ def mm_k(
     CAST_TYPE: if True, cast the values from the A matrix to the B
         matrix dtype.
     b_dtype: datatype of the B matrix
+    base_k: base offset along the K dimension for the current SPLIT_K group
     """
     accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
-    for k in range(tl.cdiv(K, BLOCK_K * SPLIT_K)):
+
+    # Step size along K for each iteration
+    STEP_K = BLOCK_K * SPLIT_K
+
+    # Total number of iterations along K
+    num_iters = tl.cdiv(K, STEP_K)
+
+    for k in range(num_iters):
+        # Current iteration's global K offset
+        iter_k = k * STEP_K + base_k
+
+        # End of this block along K, used to decide whether masking is needed
+        block_end = iter_k + BLOCK_K
+
         if EVEN_K:
-            tiled_a = tl.load(a_ptr)
-            tiled_b = tl.load(b_ptr)
+            # K is divisible by BLOCK_K, so no masking is ever needed,
+            # but skip iterations whose block starts past the end of K
+            if iter_k < K:
+                tiled_a = tl.load(a_ptr)
+                tiled_b = tl.load(b_ptr)
+                if CAST_TYPE:
+                    tiled_a = tiled_a.to(b_dtype)
+                accumulator += tl.dot(tiled_a, tiled_b)
         else:
-            tiled_a = tl.load(
-                a_ptr, mask=offset_k[None, :] < K - k * (BLOCK_K * SPLIT_K), other=0
-            )
-            tiled_b = tl.load(
-                b_ptr, mask=offset_k[:, None] < K - k * (BLOCK_K * SPLIT_K), other=0
-            )
-            if CAST_TYPE:
-                tiled_a = tiled_a.to(b_dtype)
-            accumulator += tl.dot(
-                tiled_a,
-                tiled_b,
-            )
-        a_ptr += BLOCK_K * SPLIT_K * ak_stride
-        b_ptr += BLOCK_K * SPLIT_K * bk_stride
+            # Check if we need element-wise masking
+            if iter_k >= K:
+                # Entire block out of range, skip
+                pass
+            elif block_end <= K:
+                # Entire block in range, no masking needed (fast path)
+                tiled_a = tl.load(a_ptr)
+                tiled_b = tl.load(b_ptr)
+                if CAST_TYPE:
+                    tiled_a = tiled_a.to(b_dtype)
+                accumulator += tl.dot(tiled_a, tiled_b)
+            else:
+                # Partial block, needs masking (only the last iteration)
+                k_offsets = tl.arange(0, BLOCK_K)
+                mask = iter_k + k_offsets < K
+                tiled_a = tl.load(a_ptr, mask=mask[None, :], other=0.0)
+                tiled_b = tl.load(b_ptr, mask=mask[:, None], other=0.0)
+                if CAST_TYPE:
+                    tiled_a = tiled_a.to(b_dtype)
+                accumulator += tl.dot(tiled_a, tiled_b)
+
+        a_ptr += STEP_K * ak_stride
+        b_ptr += STEP_K * bk_stride
+
     return accumulator


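To make the new control flow concrete, here is a minimal pure-Python model of the loop above (illustrative only; `simulate_mm_k` is not part of the diff). For each split-K program it prints which K-range every iteration touches and whether it takes the skip, fast, or masked path:

```python
import math

def simulate_mm_k(K: int, BLOCK_K: int, SPLIT_K: int) -> None:
    """Model the per-iteration bookkeeping of mm_k (no actual math)."""
    STEP_K = BLOCK_K * SPLIT_K
    num_iters = math.ceil(K / STEP_K)
    for pid_sk in range(SPLIT_K):
        base_k = pid_sk * BLOCK_K  # mirrors do_shrink_kernel's base_k
        for k in range(num_iters):
            iter_k = k * STEP_K + base_k
            block_end = iter_k + BLOCK_K
            if iter_k >= K:
                path = "skip (fully out of range)"
            elif block_end <= K:
                path = "fast (full block, no mask)"
            else:
                path = "masked (partial tail block)"
            print(f"pid_sk={pid_sk} iter={k}: "
                  f"K range [{iter_k}, {min(block_end, K)}) -> {path}")

simulate_mm_k(K=40, BLOCK_K=16, SPLIT_K=2)
```

With these sizes the two programs cover [0, 16) ∪ [32, 40) and [16, 32): disjoint blocks whose union is all of [0, K), with masking only on the final partial block.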
@@ -168,6 +199,7 @@ def do_expand_kernel(
         SPLIT_K,
         CAST_TYPE,
         cur_lora_ptr.dtype.element_ty,
+        base_k=0,
     )

     tiled_c = accumulator.to(cur_lora_ptr.dtype.element_ty)
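The expand call passes `base_k=0`, so `iter_k` reduces to `k * STEP_K` and the new mask predicate matches the one the deleted code computed. A quick plain-Python equivalence check (illustrative; it assumes the old `offset_k` spanned `tl.arange(0, BLOCK_K)`):

```python
# Hypothetical sizes, chosen so K is not a multiple of BLOCK_K * SPLIT_K.
K, BLOCK_K, SPLIT_K = 40, 16, 2
STEP_K = BLOCK_K * SPLIT_K
for k in range((K + STEP_K - 1) // STEP_K):
    for j in range(BLOCK_K):
        old_mask = j < K - k * STEP_K        # deleted code's predicate
        new_mask = (k * STEP_K + 0 + j) < K  # iter_k + k_offsets with base_k = 0
        assert old_mask == new_mask
print("expand path: old and new masks agree for base_k=0")
```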
@@ -272,6 +304,7 @@ def do_shrink_kernel(
         SPLIT_K,
         False,
         cur_lora_ptr.dtype.element_ty,
+        base_k=pid_sk * BLOCK_K,
     )

     # Identify the C output pointers to store the results of the accumulator.
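In the shrink kernel each split-K program starts at `base_k = pid_sk * BLOCK_K`, interleaving the programs along K. A small plain-Python sanity check (illustrative, not part of the diff) that the programs together touch every K index exactly once, including the masked tail:

```python
# Hypothetical sizes where the tail block is partial and one block is skipped.
K, BLOCK_K, SPLIT_K = 100, 16, 4
STEP_K = BLOCK_K * SPLIT_K
counts = [0] * K
for pid_sk in range(SPLIT_K):
    base_k = pid_sk * BLOCK_K
    for k in range((K + STEP_K - 1) // STEP_K):
        iter_k = k * STEP_K + base_k
        for j in range(BLOCK_K):
            if iter_k + j < K:  # same predicate as the masked load
                counts[iter_k + j] += 1
assert all(c == 1 for c in counts)
print("shrink path: every K index accumulated exactly once across programs")
```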