diff --git a/vllm/models/deepseek_v4/nvidia/ops/sparse_attn_compress_cutedsl.py b/vllm/models/deepseek_v4/nvidia/ops/sparse_attn_compress_cutedsl.py index 0eba82126afd..1e6c2ed829df 100644 --- a/vllm/models/deepseek_v4/nvidia/ops/sparse_attn_compress_cutedsl.py +++ b/vllm/models/deepseek_v4/nvidia/ops/sparse_attn_compress_cutedsl.py @@ -96,9 +96,16 @@ def __init__( self.quant_block = quant_block self.token_stride = token_stride self.scale_dim = scale_dim - self.num_warps = head_size // quant_block + self.elems_per_lane = 8 + self.copy_elems = 4 + self.copy_chunks = self.elems_per_lane // self.copy_elems + self.lanes_per_group = quant_block // self.elems_per_lane + self.groups_per_warp = 32 // self.lanes_per_group + self.scale_reduce_steps = self.lanes_per_group.bit_length() - 1 + self.scale_reduce_offset = self.lanes_per_group // 2 + self.num_warps = (head_size // quant_block) // self.groups_per_warp self.nope_blocks = self.nope_dim // quant_block - self.tb_size = head_size // 2 + self.tb_size = self.num_warps * 32 self.compress_ratio = compress_ratio self.overlap = overlap self.window = (1 + int(overlap)) * compress_ratio @@ -156,8 +163,9 @@ def kernel( tid, _, _ = cute.arch.thread_idx() warp_id = cute.arch.make_warp_uniform(tid // 32) lane_id = tid % 32 - elem0 = tid * 2 - elem1 = elem0 + 1 + group_lane = lane_id % self.lanes_per_group + group_idx = warp_id * self.groups_per_warp + lane_id // self.lanes_per_group + elem_base = group_idx * self.quant_block + group_lane * self.elems_per_lane slot_id = slot_mapping[token_idx] has_position = token_idx < positions.shape[0] @@ -201,12 +209,24 @@ def kernel( s_block_numbers[row] = block_number_i32 cute.arch.sync_threads() - max0 = -Float32.inf - max1 = -Float32.inf - sum0 = Float32(0.0) - sum1 = Float32(0.0) - product0 = Float32(0.0) - product1 = Float32(0.0) + local_max = cute.make_rmem_tensor((self.elems_per_lane,), Float32) + local_sum = cute.make_rmem_tensor((self.elems_per_lane,), Float32) + local_product = cute.make_rmem_tensor((self.elems_per_lane,), Float32) + + for e in cutlass.range_constexpr(self.elems_per_lane): + local_max[e] = -Float32.inf + local_sum[e] = Float32(0.0) + local_product[e] = Float32(0.0) + + cp_f32x4 = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), Float32, num_bits_per_copy=128 + ) + copy_layout = cute.make_layout( + (self.copy_chunks, self.copy_elems), + stride=(self.copy_elems, 1), + ) + kv_vals = cute.make_rmem_tensor(copy_layout, Float32) + score_vals = cute.make_rmem_tensor(copy_layout, Float32) for row in cutlass.range_constexpr(self.window): pos = start + Int64(row) @@ -215,46 +235,51 @@ def kernel( block_offset = pos - block_index * block_size block_number = s_block_numbers[row].to(Int64) head_offset = Int64((row // self.compress_ratio) * self.head_dim) - row_base = ( - block_number * state_cache.stride[0] - + block_offset * state_cache.stride[1] - + head_offset - ) - - score0 = state_cache.iterator[ - row_base + Int64(self.state_width) + elem0.to(Int64) - ] - kv0 = state_cache.iterator[row_base + elem0.to(Int64)] - new_max0 = cute.arch.fmax(max0, score0) - old_scale0 = cute.math.exp2( - (max0 - new_max0) * Float32(self.rcp_ln2), fastmath=True - ) - new_scale0 = cute.math.exp2( - (score0 - new_max0) * Float32(self.rcp_ln2), fastmath=True - ) - sum0 = sum0 * old_scale0 + new_scale0 - product0 = product0 * old_scale0 + kv0 * new_scale0 - max0 = new_max0 - - score1 = state_cache.iterator[ - row_base + Int64(self.state_width) + elem1.to(Int64) - ] - kv1 = state_cache.iterator[row_base + elem1.to(Int64)] - new_max1 = cute.arch.fmax(max1, score1) - old_scale1 = cute.math.exp2( - (max1 - new_max1) * Float32(self.rcp_ln2), fastmath=True - ) - new_scale1 = cute.math.exp2( - (score1 - new_max1) * Float32(self.rcp_ln2), fastmath=True - ) - sum1 = sum1 * old_scale1 + new_scale1 - product1 = product1 * old_scale1 + kv1 * new_scale1 - max1 = new_max1 + row_tensor = state_cache[block_number, block_offset, None] + for chunk in cutlass.range_constexpr(self.copy_chunks): + copy_elem = const_expr(chunk * self.copy_elems) + col_tile = ( + head_offset + (elem_base + Int32(copy_elem)).to(Int64) + ) // Int64(self.copy_elems) + kv_src = cute.local_tile( + row_tensor, + tiler=(self.copy_elems,), + coord=(col_tile,), + ) + score_src = cute.local_tile( + row_tensor, + tiler=(self.copy_elems,), + coord=( + col_tile + Int64(self.state_width // self.copy_elems), + ), + ) + cute.copy(cp_f32x4, kv_src, kv_vals[chunk, None]) + cute.copy(cp_f32x4, score_src, score_vals[chunk, None]) + + for e in cutlass.range_constexpr(self.elems_per_lane): + chunk = const_expr(e // self.copy_elems) + copy_elem = const_expr(e % self.copy_elems) + score = score_vals[chunk, copy_elem] + kv = kv_vals[chunk, copy_elem] + new_max = cute.arch.fmax(local_max[e], score) + old_scale = cute.math.exp2( + (local_max[e] - new_max) * Float32(self.rcp_ln2), + fastmath=True, + ) + new_scale = cute.math.exp2( + (score - new_max) * Float32(self.rcp_ln2), + fastmath=True, + ) + local_sum[e] = local_sum[e] * old_scale + new_scale + local_product[e] = local_product[e] * old_scale + kv * new_scale + local_max[e] = new_max - x0 = product0 / sum0 - x1 = product1 / sum1 + x = cute.make_rmem_tensor((self.elems_per_lane,), Float32) + local_sumsq = Float32(0.0) + for e in cutlass.range_constexpr(self.elems_per_lane): + x[e] = local_product[e] / local_sum[e] + local_sumsq += x[e] * x[e] - local_sumsq = x0 * x0 + x1 * x1 warp_sum = local_sumsq for step in cutlass.range_constexpr(5): offset = const_expr(16 >> step) @@ -273,8 +298,9 @@ def kernel( cute.arch.sync_threads() rrms = rrms_shared[0] - x0 = x0 * rrms * rms_norm_weight[elem0].to(Float32) - x1 = x1 * rrms * rms_norm_weight[elem1].to(Float32) + for e in cutlass.range_constexpr(self.elems_per_lane): + elem = elem_base + e + x[e] = x[e] * rrms * rms_norm_weight[elem].to(Float32) k_cache_u16 = cute.recast_tensor(k_cache, Uint16) k_cache_u32 = cute.recast_tensor(k_cache, Uint32) @@ -287,31 +313,53 @@ def kernel( + kv_offset * Int64(self.scale_dim) ) - if warp_id == self.nope_blocks: - pair_idx = lane_id + if group_idx == self.nope_blocks: compressed_pos = (position // Int64(self.compress_ratio)) * Int64( self.compress_ratio ) - cos_v = cos_sin_cache[compressed_pos, pair_idx] - sin_v = cos_sin_cache[ - compressed_pos, pair_idx + Int32(self.rope_dim // 2) - ] - real = x0 * cos_v - x1 * sin_v - imag = x0 * sin_v + x1 * cos_v - packed = _fp32x2_to_bf16x2(real, imag) - out_base = value_base + Int64(self.nope_dim) + (lane_id * 4).to(Int64) - k_cache_u32.iterator[out_base // Int64(4)] = packed + for pair in cutlass.range_constexpr(self.elems_per_lane // 2): + elem = const_expr(pair * 2) + pair_idx = (elem_base - self.nope_dim) // 2 + Int32(pair) + cos_v = cos_sin_cache[compressed_pos, pair_idx] + sin_v = cos_sin_cache[ + compressed_pos, pair_idx + Int32(self.rope_dim // 2) + ] + real = x[elem] * cos_v - x[elem + 1] * sin_v + imag = x[elem] * sin_v + x[elem + 1] * cos_v + packed = _fp32x2_to_bf16x2(real, imag) + out_base = ( + value_base + + Int64(self.nope_dim) + + ((elem_base - self.nope_dim + Int32(elem)) * 2).to(Int64) + ) + k_cache_u32.iterator[out_base // Int64(4)] = packed else: - q_packed = _fp32x2_to_bf16x2(x0, x1) - q0, q1 = _bf16x2_to_fp32(q_packed) - abs0 = cute.math.absf(q0) - abs1 = cute.math.absf(q1) - local_absmax = cute.arch.fmax(abs0, abs1) + q = cute.make_rmem_tensor((self.elems_per_lane,), Float32) + local_absmax = Float32(0.0) + for pair in cutlass.range_constexpr(self.elems_per_lane // 2): + elem = const_expr(pair * 2) + q_packed = _fp32x2_to_bf16x2(x[elem], x[elem + 1]) + q0, q1 = _bf16x2_to_fp32(q_packed) + q[elem] = q0 + q[elem + 1] = q1 + local_absmax = cute.arch.fmax( + local_absmax, + cute.arch.fmax(cute.math.absf(q0), cute.math.absf(q1)), + ) absmax = local_absmax - for step in cutlass.range_constexpr(5): - offset = const_expr(16 >> step) + group_mask_and_clamp = const_expr( + (cute.arch.WARP_SIZE - self.lanes_per_group) << 8 + | (cute.arch.WARP_SIZE - 1) + ) + for step in cutlass.range_constexpr(self.scale_reduce_steps): + offset = const_expr(self.scale_reduce_offset >> step) absmax = cute.arch.fmax( - absmax, cute.arch.shuffle_sync_bfly(absmax, offset) + absmax, + cute.arch.shuffle_sync_bfly( + absmax, + offset=offset, + mask_and_clamp=group_mask_and_clamp, + ), ) scale_raw = cute.arch.fmax( Float32(self.min_scale), @@ -320,22 +368,22 @@ def kernel( bits = _recast_val(scale_raw, Uint32) ue8m0 = ((bits + Uint32(0x7FFFFF)) >> Uint32(23)) & Uint32(0xFF) inv_scale = _recast_val((Uint32(254) - ue8m0) << Uint32(23), Float32) - y0 = cute.arch.fmin( - cute.arch.fmax(q0 * inv_scale, Float32(-self.fp8_max)), - Float32(self.fp8_max), - ) - y1 = cute.arch.fmin( - cute.arch.fmax(q1 * inv_scale, Float32(-self.fp8_max)), - Float32(self.fp8_max), - ) - packed_fp8 = _fp32x2_to_fp8e4m3x2(y0, y1) - out_base = value_base + (warp_id * self.quant_block + lane_id * 2).to( - Int64 - ) - k_cache_u16.iterator[out_base // Int64(2)] = packed_fp8 - if lane_id == 0: - k_cache.iterator[scale_base + warp_id.to(Int64)] = ue8m0.to(Uint8) - if warp_id == 0: + for pair in cutlass.range_constexpr(self.elems_per_lane // 2): + elem = const_expr(pair * 2) + y0 = cute.arch.fmin( + cute.arch.fmax(q[elem] * inv_scale, Float32(-self.fp8_max)), + Float32(self.fp8_max), + ) + y1 = cute.arch.fmin( + cute.arch.fmax(q[elem + 1] * inv_scale, Float32(-self.fp8_max)), + Float32(self.fp8_max), + ) + packed_fp8 = _fp32x2_to_fp8e4m3x2(y0, y1) + out_base = value_base + (elem_base + Int32(elem)).to(Int64) + k_cache_u16.iterator[out_base // Int64(2)] = packed_fp8 + if group_lane == 0: + k_cache.iterator[scale_base + group_idx.to(Int64)] = ue8m0.to(Uint8) + if group_idx == 0: k_cache.iterator[scale_base + Int64(self.nope_blocks)] = Uint8( 0 ) @@ -462,11 +510,11 @@ def compile( class SparseAttnCompressKernel: head_tile = 64 - rows_per_warp = 8 + rows_per_warp = 16 row_pairs_per_warp = rows_per_warp // 2 elems_per_lane = 4 lanes_per_row = head_tile // elems_per_lane - num_warps = 16 + num_warps = 8 stats_warp_stride = num_warps + 1 tb_size = num_warps * 32 rcp_ln2 = 1.4426950408889634 @@ -715,8 +763,8 @@ def kernel( local_warp_max = s_max[out_lane, out_elem, final_lane] global_max = local_warp_max - for step in cutlass.range_constexpr(4): - offset = const_expr(8 >> step) + for step in cutlass.range_constexpr(3): + offset = const_expr(4 >> step) global_max = cute.arch.fmax( global_max, cute.arch.shuffle_sync_bfly( @@ -732,8 +780,8 @@ def kernel( ) global_sum = s_sum[out_lane, out_elem, final_lane] * scale global_product = s_product[out_lane, out_elem, final_lane] * scale - for step in cutlass.range_constexpr(4): - offset = const_expr(8 >> step) + for step in cutlass.range_constexpr(3): + offset = const_expr(4 >> step) global_sum += cute.arch.shuffle_sync_bfly( global_sum, offset=offset,