Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 30 additions & 3 deletions tests/kernels/test_fused_indexer_q_rope_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,17 @@
Expects bit-exact equality on both q_fp8 and weights_out.
"""

import contextlib
from unittest import mock

import pytest
import torch

from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
per_token_group_quant_fp8,
)
from vllm.utils.import_utils import has_cutedsl
from vllm.v1.attention.ops.deepseek_v4_ops.fused_indexer_q import (
fused_indexer_q_rope_quant,
)
Expand Down Expand Up @@ -125,8 +129,14 @@ def _reference(
@pytest.mark.parametrize("num_tokens", [1, 7, 32, 257, 1023])
@pytest.mark.parametrize("cache_dtype", [torch.float32, torch.bfloat16])
@pytest.mark.parametrize("use_fp4", [False, True])
@pytest.mark.parametrize("use_cutedsl", [False, True])
@torch.inference_mode()
def test_fused_indexer_q_rope_quant_matches_unfused(num_tokens, cache_dtype, use_fp4):
def test_fused_indexer_q_rope_quant_matches_unfused(
num_tokens, cache_dtype, use_fp4, use_cutedsl
):
if use_cutedsl and not has_cutedsl():
pytest.skip("cutedsl (cutlass) not installed")

device = "cuda"
torch.manual_seed(0)

Expand All @@ -142,9 +152,26 @@ def test_fused_indexer_q_rope_quant_matches_unfused(num_tokens, cache_dtype, use
q_quant_ref, weights_ref = _reference(
positions, q, cos_sin_cache, weights, softmax_scale, head_scale, use_fp4
)
q_quant_fused, weights_fused = fused_indexer_q_rope_quant(
positions, q.clone(), cos_sin_cache, weights, softmax_scale, head_scale, use_fp4
# use_cutedsl=False: force the triton path even when cutedsl is installed
# by patching the dispatcher's has_cutedsl() binding to return False.
cutedsl_patch = (
mock.patch(
"vllm.v1.attention.ops.deepseek_v4_ops.fused_indexer_q.has_cutedsl",
return_value=False,
)
if not use_cutedsl
else contextlib.nullcontext()
)
with cutedsl_patch:
q_quant_fused, weights_fused = fused_indexer_q_rope_quant(
positions,
q.clone(),
cos_sin_cache,
weights,
softmax_scale,
head_scale,
use_fp4,
)

if use_fp4:
q_quant_ref, q_scale_ref = q_quant_ref
Expand Down
33 changes: 33 additions & 0 deletions vllm/v1/attention/ops/deepseek_v4_ops/cutedsl_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,39 @@ def _fp8x4_to_bf16x4(x: Uint32, *, loc=None, ip=None) -> cute.TensorSSA:
return cute.TensorSSA(vec, 2, Uint32)


@dsl_user_op
def _fp32x4_to_fp8x4(
    a0: Float32,
    a1: Float32,
    a2: Float32,
    a3: Float32,
    *,
    loc=None,
    ip=None,
) -> Uint32:
    """Convert four FP32 scalars to e4m3 and pack them into one 32-bit word.

    The conversion is done with two ``cvt.rn.satfinite.e4m3x2.f32`` PTX
    instructions (round-to-nearest, saturating), each producing a 16-bit
    pair, which are then merged into a single b32 register.

    Args:
        a0, a1, a2, a3: FP32 inputs. In the packed result they occupy the
            bytes in the order ``{a0, a1, a2, a3}`` from the lowest to the
            highest address.
        loc: Source location forwarded to ``ir_value`` when lowering the
            operands.
        ip: Insertion point forwarded to ``ir_value`` when lowering the
            operands.

    Returns:
        A ``Uint32`` wrapping the i32 result of the inline assembly.
    """
    # Lower the operands in argument order; they map to $1..$4 in the asm.
    operands = [src.ir_value(loc=loc, ip=ip) for src in (a0, a1, a2, a3)]

    # Each cvt takes the higher-indexed operand first, so within every 16-bit
    # pair the lower-indexed input lands in the low byte (e.g. a0 below a1).
    ptx = (
        "{\n\t"
        ".reg .b16 t0, t1;\n\t"
        "cvt.rn.satfinite.e4m3x2.f32 t0, $2, $1;\n\t"
        "cvt.rn.satfinite.e4m3x2.f32 t1, $4, $3;\n\t"
        "mov.b32 $0, {t0, t1};\n\t"
        "}\n"
    )

    packed = llvm.inline_asm(
        T.i32(),
        operands,
        ptx,
        # One 32-bit integer register output, four float register inputs.
        "=r,f,f,f,f",
        has_side_effects=False,
        is_align_stack=False,
    )
    return Uint32(packed)


@dsl_user_op
def _fp32x8_to_fp4x8(
vals: cute.Tensor,
Expand Down
57 changes: 37 additions & 20 deletions vllm/v1/attention/ops/deepseek_v4_ops/fused_indexer_q.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,24 +398,41 @@ def fused_indexer_q_rope_quant(
), index_weights_out

index_q_fp8 = torch.empty_like(index_q, dtype=torch.float8_e4m3fn)
_fused_indexer_q_rope_quant_kernel[(num_tokens, num_index_q_heads)](
positions,
index_q,
index_q.stride(0),
index_q.stride(1),
index_q_cos_sin_cache,
index_q_cos_sin_cache.stride(0),
index_q_cos_sin_cache.shape[-1] // 2,
index_q_fp8,
index_q_fp8.stride(0),
index_q_fp8.stride(1),
index_q_head_dim,
index_weights,
index_weights.stride(0),
index_weights_softmax_scale,
index_weights_head_scale,
index_weights_out,
index_weights_out.stride(0),
num_warps=1, # TODO: Tune this
)
if has_cutedsl():
# Import lazily; otherwise some tests fail due to a CUDA driver init failure.
from .fused_indexer_q_cutedsl import (
fused_indexer_q_rope_quant_fp8_cutedsl,
)

fused_indexer_q_rope_quant_fp8_cutedsl(
positions,
index_q,
index_q_cos_sin_cache,
index_weights,
index_weights_softmax_scale,
index_weights_head_scale,
index_q_fp8,
index_weights_out,
)
else:
_fused_indexer_q_rope_quant_kernel[(num_tokens, num_index_q_heads)](
positions,
index_q,
index_q.stride(0),
index_q.stride(1),
index_q_cos_sin_cache,
index_q_cos_sin_cache.stride(0),
index_q_cos_sin_cache.shape[-1] // 2,
index_q_fp8,
index_q_fp8.stride(0),
index_q_fp8.stride(1),
index_q_head_dim,
index_weights,
index_weights.stride(0),
index_weights_softmax_scale,
index_weights_head_scale,
index_weights_out,
index_weights_out.stride(0),
num_warps=1, # TODO: Tune this
)
return index_q_fp8, index_weights_out
Loading
Loading