Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
142 changes: 106 additions & 36 deletions vllm/v1/attention/ops/deepseek_v4_ops/fused_inv_rope_fp8_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import torch

from vllm.triton_utils import tl, triton
from vllm.utils.torch_utils import direct_register_custom_op


@triton.jit
Expand Down Expand Up @@ -180,34 +181,74 @@ def fused_inv_rope_fp8_quant(
fp8_dtype = torch.float8_e4m3fn
fp8_max = torch.finfo(fp8_dtype).max

fp8_buf = torch.empty(
(n_groups, num_tokens, d),
dtype=fp8_dtype,
device=o.device,
)

tma_aligned_T = get_tma_aligned_size(num_tokens, 4)
if tma_aligned_scales:
packed_sf_k = (num_scale_blocks + 3) // 4
scale_buf = torch.empty(
n_groups * packed_sf_k * tma_aligned_T,
dtype=torch.int32,
device=o.device,
).as_strided(
(n_groups, num_tokens, packed_sf_k),
(packed_sf_k * tma_aligned_T, 1, tma_aligned_T),
)
scale_inner = packed_sf_k
else:
scale_buf = torch.empty(
n_groups * num_scale_blocks * tma_aligned_T,
dtype=torch.float32,
device=o.device,
).as_strided(
(n_groups, num_tokens, num_scale_blocks),
(num_scale_blocks * tma_aligned_T, 1, tma_aligned_T),
)
scale_inner = num_scale_blocks

# Run kernel through a custom op so inductor sees an opaque boundary.
# It's a pytorch bug, see https://github.com/vllm-project/vllm/issues/41106
fp8_buf, scale_buf = torch.ops.vllm.fused_inv_rope_fp8_quant_kernel(
o,
positions,
cos_sin_cache,
heads_per_group,
quant_group_size,
chunks_per_head,
nope_dim % quant_group_size,
rope_dim // 2,
tma_aligned_scales,
fp8_max,
tma_aligned_T,
num_tokens,
n_groups,
d,
scale_inner,
)
return fp8_buf.transpose(0, 1), scale_buf.transpose(0, 1)


common_args = dict(
def _fused_inv_rope_fp8_quant_kernel_impl(
o: torch.Tensor,
positions: torch.Tensor,
cos_sin_cache: torch.Tensor,
heads_per_group: int,
quant_group_size: int,
chunks_per_head: int,
rope_start: int,
half_rope: int,
tma_aligned_scales: bool,
fp8_max: float,
tma_aligned_T: int,
num_tokens: int,
n_groups: int,
d: int,
scale_inner: int,
) -> tuple[torch.Tensor, torch.Tensor]:
fp8_buf = torch.empty(
(n_groups, num_tokens, d),
dtype=torch.float8_e4m3fn,
device=o.device,
)
scale_dtype = torch.int32 if tma_aligned_scales else torch.float32
scale_buf = torch.empty(
n_groups * scale_inner * tma_aligned_T,
dtype=scale_dtype,
device=o.device,
).as_strided(
(n_groups, num_tokens, scale_inner),
(scale_inner * tma_aligned_T, 1, tma_aligned_T),
)
grid = (tma_aligned_T, n_groups * heads_per_group)
_fused_inv_rope_fp8_quant_per_head[grid](
o,
positions,
cos_sin_cache,
fp8_buf,
scale_buf,
num_tokens,
heads_per_group=heads_per_group,
o_stride_token=o.stride(0),
o_stride_head=o.stride(1),
Expand All @@ -220,23 +261,52 @@ def fused_inv_rope_fp8_quant(
eps=1e-10,
QUANT_GROUP_SIZE=quant_group_size,
CHUNKS_PER_HEAD=chunks_per_head,
ROPE_START=nope_dim % quant_group_size,
HALF_ROPE=rope_dim // 2,
ROPE_START=rope_start,
HALF_ROPE=half_rope,
TMA_ALIGNED_SCALES=tma_aligned_scales,
num_stages=1,
launch_pdl=False,
num_warps=1,
)
return fp8_buf, scale_buf

grid = (tma_aligned_T, n_groups * heads_per_group)
_fused_inv_rope_fp8_quant_per_head[grid](
o,
positions,
cos_sin_cache,
fp8_buf,
scale_buf,
num_tokens,
**common_args,
num_warps=1,

def _fused_inv_rope_fp8_quant_kernel_fake(
o: torch.Tensor,
positions: torch.Tensor,
cos_sin_cache: torch.Tensor,
heads_per_group: int,
quant_group_size: int,
chunks_per_head: int,
rope_start: int,
half_rope: int,
tma_aligned_scales: bool,
fp8_max: float,
tma_aligned_T: int,
num_tokens: int,
n_groups: int,
d: int,
scale_inner: int,
) -> tuple[torch.Tensor, torch.Tensor]:
fp8_buf = torch.empty(
(n_groups, num_tokens, d),
dtype=torch.float8_e4m3fn,
device=o.device,
)
scale_dtype = torch.int32 if tma_aligned_scales else torch.float32
scale_buf = torch.empty(
n_groups * scale_inner * tma_aligned_T,
dtype=scale_dtype,
device=o.device,
).as_strided(
(n_groups, num_tokens, scale_inner),
(scale_inner * tma_aligned_T, 1, tma_aligned_T),
)
return fp8_buf, scale_buf

return fp8_buf.transpose(0, 1), scale_buf.transpose(0, 1)

direct_register_custom_op(
op_name="fused_inv_rope_fp8_quant_kernel",
op_func=_fused_inv_rope_fp8_quant_kernel_impl,
fake_impl=_fused_inv_rope_fp8_quant_kernel_fake,
)
Loading