Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
230 changes: 139 additions & 91 deletions vllm/models/deepseek_v4/nvidia/ops/sparse_attn_compress_cutedsl.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,16 @@ def __init__(
self.quant_block = quant_block
self.token_stride = token_stride
self.scale_dim = scale_dim
self.num_warps = head_size // quant_block
self.elems_per_lane = 8
self.copy_elems = 4
self.copy_chunks = self.elems_per_lane // self.copy_elems
self.lanes_per_group = quant_block // self.elems_per_lane
self.groups_per_warp = 32 // self.lanes_per_group
self.scale_reduce_steps = self.lanes_per_group.bit_length() - 1
self.scale_reduce_offset = self.lanes_per_group // 2
self.num_warps = (head_size // quant_block) // self.groups_per_warp
self.nope_blocks = self.nope_dim // quant_block
self.tb_size = head_size // 2
self.tb_size = self.num_warps * 32
self.compress_ratio = compress_ratio
self.overlap = overlap
self.window = (1 + int(overlap)) * compress_ratio
Expand Down Expand Up @@ -156,8 +163,9 @@ def kernel(
tid, _, _ = cute.arch.thread_idx()
warp_id = cute.arch.make_warp_uniform(tid // 32)
lane_id = tid % 32
elem0 = tid * 2
elem1 = elem0 + 1
group_lane = lane_id % self.lanes_per_group
group_idx = warp_id * self.groups_per_warp + lane_id // self.lanes_per_group
elem_base = group_idx * self.quant_block + group_lane * self.elems_per_lane

slot_id = slot_mapping[token_idx]
has_position = token_idx < positions.shape[0]
Expand Down Expand Up @@ -201,12 +209,24 @@ def kernel(
s_block_numbers[row] = block_number_i32
cute.arch.sync_threads()

max0 = -Float32.inf
max1 = -Float32.inf
sum0 = Float32(0.0)
sum1 = Float32(0.0)
product0 = Float32(0.0)
product1 = Float32(0.0)
local_max = cute.make_rmem_tensor((self.elems_per_lane,), Float32)
local_sum = cute.make_rmem_tensor((self.elems_per_lane,), Float32)
local_product = cute.make_rmem_tensor((self.elems_per_lane,), Float32)

for e in cutlass.range_constexpr(self.elems_per_lane):
local_max[e] = -Float32.inf
local_sum[e] = Float32(0.0)
local_product[e] = Float32(0.0)

cp_f32x4 = cute.make_copy_atom(
cute.nvgpu.CopyUniversalOp(), Float32, num_bits_per_copy=128
)
copy_layout = cute.make_layout(
(self.copy_chunks, self.copy_elems),
stride=(self.copy_elems, 1),
)
kv_vals = cute.make_rmem_tensor(copy_layout, Float32)
score_vals = cute.make_rmem_tensor(copy_layout, Float32)

for row in cutlass.range_constexpr(self.window):
pos = start + Int64(row)
Expand All @@ -215,46 +235,51 @@ def kernel(
block_offset = pos - block_index * block_size
block_number = s_block_numbers[row].to(Int64)
head_offset = Int64((row // self.compress_ratio) * self.head_dim)
row_base = (
block_number * state_cache.stride[0]
+ block_offset * state_cache.stride[1]
+ head_offset
)

score0 = state_cache.iterator[
row_base + Int64(self.state_width) + elem0.to(Int64)
]
kv0 = state_cache.iterator[row_base + elem0.to(Int64)]
new_max0 = cute.arch.fmax(max0, score0)
old_scale0 = cute.math.exp2(
(max0 - new_max0) * Float32(self.rcp_ln2), fastmath=True
)
new_scale0 = cute.math.exp2(
(score0 - new_max0) * Float32(self.rcp_ln2), fastmath=True
)
sum0 = sum0 * old_scale0 + new_scale0
product0 = product0 * old_scale0 + kv0 * new_scale0
max0 = new_max0

score1 = state_cache.iterator[
row_base + Int64(self.state_width) + elem1.to(Int64)
]
kv1 = state_cache.iterator[row_base + elem1.to(Int64)]
new_max1 = cute.arch.fmax(max1, score1)
old_scale1 = cute.math.exp2(
(max1 - new_max1) * Float32(self.rcp_ln2), fastmath=True
)
new_scale1 = cute.math.exp2(
(score1 - new_max1) * Float32(self.rcp_ln2), fastmath=True
)
sum1 = sum1 * old_scale1 + new_scale1
product1 = product1 * old_scale1 + kv1 * new_scale1
max1 = new_max1
row_tensor = state_cache[block_number, block_offset, None]
for chunk in cutlass.range_constexpr(self.copy_chunks):
copy_elem = const_expr(chunk * self.copy_elems)
col_tile = (
head_offset + (elem_base + Int32(copy_elem)).to(Int64)
) // Int64(self.copy_elems)
kv_src = cute.local_tile(
row_tensor,
tiler=(self.copy_elems,),
coord=(col_tile,),
)
score_src = cute.local_tile(
row_tensor,
tiler=(self.copy_elems,),
coord=(
col_tile + Int64(self.state_width // self.copy_elems),
),
)
cute.copy(cp_f32x4, kv_src, kv_vals[chunk, None])
cute.copy(cp_f32x4, score_src, score_vals[chunk, None])

for e in cutlass.range_constexpr(self.elems_per_lane):
chunk = const_expr(e // self.copy_elems)
copy_elem = const_expr(e % self.copy_elems)
score = score_vals[chunk, copy_elem]
kv = kv_vals[chunk, copy_elem]
new_max = cute.arch.fmax(local_max[e], score)
old_scale = cute.math.exp2(
(local_max[e] - new_max) * Float32(self.rcp_ln2),
fastmath=True,
)
new_scale = cute.math.exp2(
(score - new_max) * Float32(self.rcp_ln2),
fastmath=True,
)
local_sum[e] = local_sum[e] * old_scale + new_scale
local_product[e] = local_product[e] * old_scale + kv * new_scale
local_max[e] = new_max

x0 = product0 / sum0
x1 = product1 / sum1
x = cute.make_rmem_tensor((self.elems_per_lane,), Float32)
local_sumsq = Float32(0.0)
for e in cutlass.range_constexpr(self.elems_per_lane):
x[e] = local_product[e] / local_sum[e]
local_sumsq += x[e] * x[e]

local_sumsq = x0 * x0 + x1 * x1
warp_sum = local_sumsq
for step in cutlass.range_constexpr(5):
offset = const_expr(16 >> step)
Expand All @@ -273,8 +298,9 @@ def kernel(
cute.arch.sync_threads()

rrms = rrms_shared[0]
x0 = x0 * rrms * rms_norm_weight[elem0].to(Float32)
x1 = x1 * rrms * rms_norm_weight[elem1].to(Float32)
for e in cutlass.range_constexpr(self.elems_per_lane):
elem = elem_base + e
x[e] = x[e] * rrms * rms_norm_weight[elem].to(Float32)

k_cache_u16 = cute.recast_tensor(k_cache, Uint16)
k_cache_u32 = cute.recast_tensor(k_cache, Uint32)
Expand All @@ -287,31 +313,53 @@ def kernel(
+ kv_offset * Int64(self.scale_dim)
)

if warp_id == self.nope_blocks:
pair_idx = lane_id
if group_idx == self.nope_blocks:
compressed_pos = (position // Int64(self.compress_ratio)) * Int64(
self.compress_ratio
)
cos_v = cos_sin_cache[compressed_pos, pair_idx]
sin_v = cos_sin_cache[
compressed_pos, pair_idx + Int32(self.rope_dim // 2)
]
real = x0 * cos_v - x1 * sin_v
imag = x0 * sin_v + x1 * cos_v
packed = _fp32x2_to_bf16x2(real, imag)
out_base = value_base + Int64(self.nope_dim) + (lane_id * 4).to(Int64)
k_cache_u32.iterator[out_base // Int64(4)] = packed
for pair in cutlass.range_constexpr(self.elems_per_lane // 2):
elem = const_expr(pair * 2)
pair_idx = (elem_base - self.nope_dim) // 2 + Int32(pair)
cos_v = cos_sin_cache[compressed_pos, pair_idx]
sin_v = cos_sin_cache[
compressed_pos, pair_idx + Int32(self.rope_dim // 2)
]
real = x[elem] * cos_v - x[elem + 1] * sin_v
imag = x[elem] * sin_v + x[elem + 1] * cos_v
packed = _fp32x2_to_bf16x2(real, imag)
out_base = (
value_base
+ Int64(self.nope_dim)
+ ((elem_base - self.nope_dim + Int32(elem)) * 2).to(Int64)
)
k_cache_u32.iterator[out_base // Int64(4)] = packed
else:
q_packed = _fp32x2_to_bf16x2(x0, x1)
q0, q1 = _bf16x2_to_fp32(q_packed)
abs0 = cute.math.absf(q0)
abs1 = cute.math.absf(q1)
local_absmax = cute.arch.fmax(abs0, abs1)
q = cute.make_rmem_tensor((self.elems_per_lane,), Float32)
local_absmax = Float32(0.0)
for pair in cutlass.range_constexpr(self.elems_per_lane // 2):
elem = const_expr(pair * 2)
q_packed = _fp32x2_to_bf16x2(x[elem], x[elem + 1])
q0, q1 = _bf16x2_to_fp32(q_packed)
q[elem] = q0
q[elem + 1] = q1
local_absmax = cute.arch.fmax(
local_absmax,
cute.arch.fmax(cute.math.absf(q0), cute.math.absf(q1)),
)
absmax = local_absmax
for step in cutlass.range_constexpr(5):
offset = const_expr(16 >> step)
group_mask_and_clamp = const_expr(
(cute.arch.WARP_SIZE - self.lanes_per_group) << 8
| (cute.arch.WARP_SIZE - 1)
)
for step in cutlass.range_constexpr(self.scale_reduce_steps):
offset = const_expr(self.scale_reduce_offset >> step)
absmax = cute.arch.fmax(
absmax, cute.arch.shuffle_sync_bfly(absmax, offset)
absmax,
cute.arch.shuffle_sync_bfly(
absmax,
offset=offset,
mask_and_clamp=group_mask_and_clamp,
),
)
scale_raw = cute.arch.fmax(
Float32(self.min_scale),
Expand All @@ -320,22 +368,22 @@ def kernel(
bits = _recast_val(scale_raw, Uint32)
ue8m0 = ((bits + Uint32(0x7FFFFF)) >> Uint32(23)) & Uint32(0xFF)
inv_scale = _recast_val((Uint32(254) - ue8m0) << Uint32(23), Float32)
y0 = cute.arch.fmin(
cute.arch.fmax(q0 * inv_scale, Float32(-self.fp8_max)),
Float32(self.fp8_max),
)
y1 = cute.arch.fmin(
cute.arch.fmax(q1 * inv_scale, Float32(-self.fp8_max)),
Float32(self.fp8_max),
)
packed_fp8 = _fp32x2_to_fp8e4m3x2(y0, y1)
out_base = value_base + (warp_id * self.quant_block + lane_id * 2).to(
Int64
)
k_cache_u16.iterator[out_base // Int64(2)] = packed_fp8
if lane_id == 0:
k_cache.iterator[scale_base + warp_id.to(Int64)] = ue8m0.to(Uint8)
if warp_id == 0:
for pair in cutlass.range_constexpr(self.elems_per_lane // 2):
elem = const_expr(pair * 2)
y0 = cute.arch.fmin(
cute.arch.fmax(q[elem] * inv_scale, Float32(-self.fp8_max)),
Float32(self.fp8_max),
)
y1 = cute.arch.fmin(
cute.arch.fmax(q[elem + 1] * inv_scale, Float32(-self.fp8_max)),
Float32(self.fp8_max),
)
packed_fp8 = _fp32x2_to_fp8e4m3x2(y0, y1)
out_base = value_base + (elem_base + Int32(elem)).to(Int64)
k_cache_u16.iterator[out_base // Int64(2)] = packed_fp8
if group_lane == 0:
k_cache.iterator[scale_base + group_idx.to(Int64)] = ue8m0.to(Uint8)
if group_idx == 0:
k_cache.iterator[scale_base + Int64(self.nope_blocks)] = Uint8(
0
)
Expand Down Expand Up @@ -462,11 +510,11 @@ def compile(

class SparseAttnCompressKernel:
head_tile = 64
rows_per_warp = 8
rows_per_warp = 16
row_pairs_per_warp = rows_per_warp // 2
elems_per_lane = 4
lanes_per_row = head_tile // elems_per_lane
num_warps = 16
num_warps = 8
stats_warp_stride = num_warps + 1
tb_size = num_warps * 32
rcp_ln2 = 1.4426950408889634
Expand Down Expand Up @@ -715,8 +763,8 @@ def kernel(

local_warp_max = s_max[out_lane, out_elem, final_lane]
global_max = local_warp_max
for step in cutlass.range_constexpr(4):
offset = const_expr(8 >> step)
for step in cutlass.range_constexpr(3):
offset = const_expr(4 >> step)
global_max = cute.arch.fmax(
global_max,
cute.arch.shuffle_sync_bfly(
Expand All @@ -732,8 +780,8 @@ def kernel(
)
global_sum = s_sum[out_lane, out_elem, final_lane] * scale
global_product = s_product[out_lane, out_elem, final_lane] * scale
for step in cutlass.range_constexpr(4):
offset = const_expr(8 >> step)
for step in cutlass.range_constexpr(3):
offset = const_expr(4 >> step)
global_sum += cute.arch.shuffle_sync_bfly(
global_sum,
offset=offset,
Expand Down
Loading