diff --git a/vllm/v1/attention/ops/deepseek_v4_ops/fused_inv_rope_fp8_quant.py b/vllm/v1/attention/ops/deepseek_v4_ops/fused_inv_rope_fp8_quant.py
index 97c9538889a1..84647d6120d8 100644
--- a/vllm/v1/attention/ops/deepseek_v4_ops/fused_inv_rope_fp8_quant.py
+++ b/vllm/v1/attention/ops/deepseek_v4_ops/fused_inv_rope_fp8_quant.py
@@ -10,6 +10,7 @@ import torch
 
 from vllm.triton_utils import tl, triton
+from vllm.utils.torch_utils import direct_register_custom_op
 
 
 @triton.jit
@@ -180,34 +181,74 @@ def fused_inv_rope_fp8_quant(
     fp8_dtype = torch.float8_e4m3fn
     fp8_max = torch.finfo(fp8_dtype).max
-    fp8_buf = torch.empty(
-        (n_groups, num_tokens, d),
-        dtype=fp8_dtype,
-        device=o.device,
-    )
-
     tma_aligned_T = get_tma_aligned_size(num_tokens, 4)
     if tma_aligned_scales:
         packed_sf_k = (num_scale_blocks + 3) // 4
-        scale_buf = torch.empty(
-            n_groups * packed_sf_k * tma_aligned_T,
-            dtype=torch.int32,
-            device=o.device,
-        ).as_strided(
-            (n_groups, num_tokens, packed_sf_k),
-            (packed_sf_k * tma_aligned_T, 1, tma_aligned_T),
-        )
+        scale_inner = packed_sf_k
     else:
-        scale_buf = torch.empty(
-            n_groups * num_scale_blocks * tma_aligned_T,
-            dtype=torch.float32,
-            device=o.device,
-        ).as_strided(
-            (n_groups, num_tokens, num_scale_blocks),
-            (num_scale_blocks * tma_aligned_T, 1, tma_aligned_T),
-        )
+        scale_inner = num_scale_blocks
+
+    # Run the kernel through a custom op so inductor sees an opaque boundary.
+    # This works around a PyTorch bug; see
+    # https://github.com/vllm-project/vllm/issues/41106
+    fp8_buf, scale_buf = torch.ops.vllm.fused_inv_rope_fp8_quant_kernel(
+        o,
+        positions,
+        cos_sin_cache,
+        heads_per_group,
+        quant_group_size,
+        chunks_per_head,
+        nope_dim % quant_group_size,
+        rope_dim // 2,
+        tma_aligned_scales,
+        fp8_max,
+        tma_aligned_T,
+        num_tokens,
+        n_groups,
+        d,
+        scale_inner,
+    )
+    return fp8_buf.transpose(0, 1), scale_buf.transpose(0, 1)
+
 
-    common_args = dict(
+def _fused_inv_rope_fp8_quant_kernel_impl(
+    o: torch.Tensor,
+    positions: torch.Tensor,
+    cos_sin_cache: torch.Tensor,
+    heads_per_group: int,
+    quant_group_size: int,
+    chunks_per_head: int,
+    rope_start: int,
+    half_rope: int,
+    tma_aligned_scales: bool,
+    fp8_max: float,
+    tma_aligned_T: int,
+    num_tokens: int,
+    n_groups: int,
+    d: int,
+    scale_inner: int,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    fp8_buf = torch.empty(
+        (n_groups, num_tokens, d),
+        dtype=torch.float8_e4m3fn,
+        device=o.device,
+    )
+    scale_dtype = torch.int32 if tma_aligned_scales else torch.float32
+    scale_buf = torch.empty(
+        n_groups * scale_inner * tma_aligned_T,
+        dtype=scale_dtype,
+        device=o.device,
+    ).as_strided(
+        (n_groups, num_tokens, scale_inner),
+        (scale_inner * tma_aligned_T, 1, tma_aligned_T),
+    )
+    grid = (tma_aligned_T, n_groups * heads_per_group)
+    _fused_inv_rope_fp8_quant_per_head[grid](
+        o,
+        positions,
+        cos_sin_cache,
+        fp8_buf,
+        scale_buf,
+        num_tokens,
         heads_per_group=heads_per_group,
         o_stride_token=o.stride(0),
         o_stride_head=o.stride(1),
@@ -220,23 +261,52 @@ def fused_inv_rope_fp8_quant(
         eps=1e-10,
         QUANT_GROUP_SIZE=quant_group_size,
         CHUNKS_PER_HEAD=chunks_per_head,
-        ROPE_START=nope_dim % quant_group_size,
-        HALF_ROPE=rope_dim // 2,
+        ROPE_START=rope_start,
+        HALF_ROPE=half_rope,
         TMA_ALIGNED_SCALES=tma_aligned_scales,
         num_stages=1,
         launch_pdl=False,
+        num_warps=1,
     )
+    return fp8_buf, scale_buf
 
-    grid = (tma_aligned_T, n_groups * heads_per_group)
-    _fused_inv_rope_fp8_quant_per_head[grid](
-        o,
-        positions,
-        cos_sin_cache,
-        fp8_buf,
-        scale_buf,
-        num_tokens,
-        **common_args,
-        num_warps=1,
+
+def _fused_inv_rope_fp8_quant_kernel_fake(
+    o: torch.Tensor,
+    positions: torch.Tensor,
+    cos_sin_cache: torch.Tensor,
+    heads_per_group: int,
+    quant_group_size: int,
+    chunks_per_head: int,
+    rope_start: int,
+    half_rope: int,
+    tma_aligned_scales: bool,
+    fp8_max: float,
+    tma_aligned_T: int,
+    num_tokens: int,
+    n_groups: int,
+    d: int,
+    scale_inner: int,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    fp8_buf = torch.empty(
+        (n_groups, num_tokens, d),
+        dtype=torch.float8_e4m3fn,
+        device=o.device,
     )
+    scale_dtype = torch.int32 if tma_aligned_scales else torch.float32
+    scale_buf = torch.empty(
+        n_groups * scale_inner * tma_aligned_T,
+        dtype=scale_dtype,
+        device=o.device,
+    ).as_strided(
+        (n_groups, num_tokens, scale_inner),
+        (scale_inner * tma_aligned_T, 1, tma_aligned_T),
+    )
+    return fp8_buf, scale_buf
 
-    return fp8_buf.transpose(0, 1), scale_buf.transpose(0, 1)
+
+direct_register_custom_op(
+    op_name="fused_inv_rope_fp8_quant_kernel",
+    op_func=_fused_inv_rope_fp8_quant_kernel_impl,
+    fake_impl=_fused_inv_rope_fp8_quant_kernel_fake,
+)
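For readers unfamiliar with the pattern this patch relies on, here is a minimal, self-contained sketch (not vLLM code) of the same impl/fake split, written against PyTorch's public `torch.library.custom_op` API, which `direct_register_custom_op` is assumed here to wrap. The op name `demo::scaled_quant` and its toy body are hypothetical. The key idea is that the fake implementation only reports output metadata (shape/dtype/device), which is what lets `torch.compile`/inductor trace past the op as a single opaque node instead of tracing into the kernel launch:

```python
import torch


# Real implementation: runs eagerly and is opaque to torch.compile,
# playing the role of _fused_inv_rope_fp8_quant_kernel_impl above.
@torch.library.custom_op("demo::scaled_quant", mutates_args=())
def scaled_quant(x: torch.Tensor, scale: float) -> torch.Tensor:
    return (x * scale).to(torch.float8_e4m3fn)


# Fake (meta) implementation: allocates a metadata-only output with the
# shape/dtype the real op would produce, playing the role of
# _fused_inv_rope_fp8_quant_kernel_fake above.
@scaled_quant.register_fake
def _(x: torch.Tensor, scale: float) -> torch.Tensor:
    return torch.empty_like(x, dtype=torch.float8_e4m3fn)


@torch.compile
def f(x: torch.Tensor) -> torch.Tensor:
    # Inductor sees one opaque custom-op node here; the op body is never
    # traced, which is the boundary the patch establishes.
    return torch.ops.demo.scaled_quant(x, 0.5)
```

During compilation only the fake runs, so the buffer allocation and `as_strided` view in the real impl stay hidden from inductor, while eager execution still launches the Triton kernel unchanged.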