Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 51 additions & 18 deletions vllm/lora/ops/triton_ops/kernel_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def mm_k(
CAST_TYPE: tl.constexpr,
b_dtype: tl.constexpr,
USE_GDC: tl.constexpr,
base_k,
):
"""
Given a_ptr and b_ptr, that identify the rows of A (m x k) and columns of
Expand All @@ -47,32 +48,62 @@ def mm_k(
matrix dtype.
b_dtype: datatype of the B matrix
USE_GDC: Whether to use PDL. True indicates use.
base_k: Base offset along K dimension for current SPLIT_K group
"""
accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k in range(tl.cdiv(K, BLOCK_K * SPLIT_K)):

# Step size along K for each iteration
STEP_K = BLOCK_K * SPLIT_K

# Total number of iterations (compile-time constant)
num_iters = tl.cdiv(K, STEP_K)

for k in range(num_iters):
# Current iteration's global K offset
iter_k = k * STEP_K + base_k

# Exclusive end of this iteration's K block; compared against K below to choose the unmasked or masked load path
block_end = iter_k + BLOCK_K

if EVEN_K:
# pre-fetech lora weight
# K is divisible by BLOCK_K, no masking ever needed
# pre-fetch lora weight
tiled_b = tl.load(b_ptr)
if USE_GDC:
tl.extra.cuda.gdc_wait()
tiled_a = tl.load(a_ptr)
if CAST_TYPE:
tiled_a = tiled_a.to(b_dtype)
accumulator += tl.dot(tiled_a, tiled_b)
else:
tiled_b = tl.load(
b_ptr, mask=offset_k[:, None] < K - k * (BLOCK_K * SPLIT_K), other=0
)
if USE_GDC:
tl.extra.cuda.gdc_wait()
tiled_a = tl.load(
a_ptr, mask=offset_k[None, :] < K - k * (BLOCK_K * SPLIT_K), other=0
)
if CAST_TYPE:
tiled_a = tiled_a.to(b_dtype)
accumulator += tl.dot(
tiled_a,
tiled_b,
)
a_ptr += BLOCK_K * SPLIT_K * ak_stride
b_ptr += BLOCK_K * SPLIT_K * bk_stride
# Check if we need element-wise masking
if iter_k >= K:
# Entire block out of range, skip
pass
elif block_end <= K:
# Entire block in range, no masking needed (fast path)
tiled_b = tl.load(b_ptr)
if USE_GDC:
tl.extra.cuda.gdc_wait()
tiled_a = tl.load(a_ptr)
if CAST_TYPE:
tiled_a = tiled_a.to(b_dtype)
accumulator += tl.dot(tiled_a, tiled_b)
else:
# Partial block, need masking (only last iteration)
k_offsets = tl.arange(0, BLOCK_K)
mask = iter_k + k_offsets < K
tiled_b = tl.load(b_ptr, mask=mask[:, None], other=0.0)
if USE_GDC:
tl.extra.cuda.gdc_wait()
tiled_a = tl.load(a_ptr, mask=mask[None, :], other=0.0)
if CAST_TYPE:
tiled_a = tiled_a.to(b_dtype)
accumulator += tl.dot(tiled_a, tiled_b)

a_ptr += STEP_K * ak_stride
b_ptr += STEP_K * bk_stride

return accumulator


Expand Down Expand Up @@ -178,6 +209,7 @@ def do_expand_kernel(
CAST_TYPE,
cur_lora_ptr.dtype.element_ty,
USE_GDC,
base_k=0,
)

tiled_c = accumulator.to(cur_lora_ptr.dtype.element_ty)
Expand Down Expand Up @@ -284,6 +316,7 @@ def do_shrink_kernel(
False,
cur_lora_ptr.dtype.element_ty,
False, # USE_GDC is always False in shrink kernel
base_k=pid_sk * BLOCK_K,
)
# GDC launch dependents hints the runtime system to launch dependent kernels.
if USE_GDC:
Expand Down