Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 44 additions & 7 deletions python/sglang/srt/layers/attention/trtllm_fp8_kv_kernel.py
Comment thread
b8zhong marked this conversation as resolved.
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,9 @@ def _fused_fp8_set_kv_buffer_kernel(
v_cache_ptr, # [total_slots, num_kv_heads, head_dim]
# Cache location indices
cache_loc_ptr, # [num_tokens] -> token to cache location mapping
# Scalar scale (if provided, will be used; otherwise computed per-token)
k_scale, # scalar float
v_scale, # scalar float
# Pointers to scalar inverse scales (computed on GPU in wrapper)
inv_k_scale_ptr, # pointer to 0-D tensor on GPU
inv_v_scale_ptr, # pointer to 0-D tensor on GPU
use_provided_scale: tl.constexpr, # whether to use provided scale
# Tensor dimensions
num_kv_heads: tl.constexpr,
Expand Down Expand Up @@ -147,7 +147,10 @@ def _fused_fp8_set_kv_buffer_kernel(
# Select K or V based on kv_idx
if kv_idx == 0:
# Process K tensor
inv_scale = 1.0 / k_scale if use_provided_scale else 1.0
if use_provided_scale:
inv_scale = tl.load(inv_k_scale_ptr)
else:
inv_scale = 1.0
_process_kv_tensor(
token_id,
head_block_id,
Expand All @@ -171,7 +174,10 @@ def _fused_fp8_set_kv_buffer_kernel(
)
else:
# Process V tensor
inv_scale = 1.0 / v_scale if use_provided_scale else 1.0
if use_provided_scale:
inv_scale = tl.load(inv_v_scale_ptr)
else:
inv_scale = 1.0
_process_kv_tensor(
token_id,
head_block_id,
Expand Down Expand Up @@ -343,15 +349,46 @@ def fused_fp8_set_kv_buffer(
# - dim 2: K/V (0=K, 1=V)
grid = (num_tokens, num_head_blocks, 2)

device = k_3d.device

def _to_tensor_scale(scale):
"""Convert scale to 0-D CUDA tensor (accepts Python float or Tensor)."""
if isinstance(scale, torch.Tensor):
return scale.to(device=device, dtype=torch.float32)
else:
# Python float / np scalar
return torch.tensor(float(scale), device=device, dtype=torch.float32)

# Compute inverse scales on GPU to avoid GPU→CPU sync in CUDA graph capture.
# Previously we used float(k_scale) which triggers synchronization and fails
# during CUDA graph capture with cudaErrorStreamCaptureUnsupported.
if use_provided_scale:
k_scale_tensor = _to_tensor_scale(k_scale)
v_scale_tensor = _to_tensor_scale(v_scale)

# Pure GPU scalar operation, safe for CUDA graph
inv_k_scale = (1.0 / k_scale_tensor).to(device=device, dtype=torch.float32)
inv_v_scale = (1.0 / v_scale_tensor).to(device=device, dtype=torch.float32)

inv_k_scale_ptr = inv_k_scale
inv_v_scale_ptr = inv_v_scale
else:
# When use_provided_scale=False, kernel uses constant 1.0 for inv_scale.
# Triton will optimize away the tl.load() calls via constant folding.
# We pass dummy pointers (k_3d) which won't be accessed in the kernel.
# This avoids creating new GPU tensors during CUDA graph capture.
inv_k_scale_ptr = k_3d
inv_v_scale_ptr = k_3d

# Launch Triton kernel
_fused_fp8_set_kv_buffer_kernel[grid](
k_3d,
v_3d,
k_cache,
v_cache,
cache_loc,
k_scale if k_scale is not None else 1.0,
v_scale if v_scale is not None else 1.0,
inv_k_scale_ptr,
inv_v_scale_ptr,
use_provided_scale,
num_kv_heads,
head_dim,
Expand Down
173 changes: 173 additions & 0 deletions test/manual/test_trtllm_fp8_kv_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,179 @@ def test_empty_input(self):
page_size=page_size,
)

def test_fp8_kv_kernel_accepts_tensor_scales(self):
    """
    Regression test for B200 Triton compilation issue.

    Ensures that fused_fp8_set_kv_buffer correctly handles k_scale/v_scale
    when they are 0-dimensional CUDA tensors (e.g. torch.nn.Parameter).

    Previously, Triton treated 0-D tensor arguments as pointers, so
    computing "1.0 / k_scale" inside the kernel raised an
    IncompatibleTypeError. The fix computes the inverse scales on the GPU
    in the wrapper and passes pointers to 0-D float32 tensors into the
    kernel, which also keeps the path free of GPU->CPU synchronization.
    """
    device = torch.device("cuda")

    num_tokens = 4
    num_kv_heads = 2
    head_dim = 64
    page_size = 16
    total_slots = page_size

    k = torch.randn(
        num_tokens, num_kv_heads, head_dim, device=device, dtype=torch.bfloat16
    )
    v = torch.randn_like(k)

    # Cache layout: [total_slots, num_kv_heads, head_dim]
    k_cache = torch.empty(
        total_slots,
        num_kv_heads,
        head_dim,
        device=device,
        dtype=torch.float8_e4m3fn,
    )
    v_cache = torch.empty_like(k_cache)

    cache_loc = torch.arange(num_tokens, device=device, dtype=torch.int32)

    # Use 0-D tensor scales to reproduce the original bug scenario.
    k_scale = torch.tensor(1.0, device=device, dtype=torch.float32)
    v_scale = torch.tensor(1.0, device=device, dtype=torch.float32)

    # Old code would trigger Triton's IncompatibleTypeError here.
    fused_fp8_set_kv_buffer(
        k,
        v,
        k_cache,
        v_cache,
        cache_loc,
        k_scale=k_scale,
        v_scale=v_scale,
        page_size=page_size,
        use_triton=True,
    )

    # Beyond "no exception": with scale == 1.0 the cached values must
    # round-trip through fp8 within e4m3 quantization error. Tolerances
    # are generous to cover e4m3's ~6% relative precision.
    torch.testing.assert_close(
        k_cache[:num_tokens].to(torch.float32),
        k.to(torch.float32),
        rtol=0.125,
        atol=0.125,
    )
    torch.testing.assert_close(
        v_cache[:num_tokens].to(torch.float32),
        v.to(torch.float32),
        rtol=0.125,
        atol=0.125,
    )

def test_fp8_kv_kernel_cuda_graph_compatible(self):
    """
    Regression test for CUDA graph capture compatibility.

    fused_fp8_set_kv_buffer must be capturable inside a CUDA graph, which
    production uses for performance. Previously, float(k_scale) forced a
    GPU->CPU synchronization and graph capture failed with
    cudaErrorStreamCaptureUnsupported; the fix computes the inverse
    scales purely on the GPU.
    """
    dev = torch.device("cuda")

    num_tokens, num_kv_heads, head_dim = 4, 2, 64
    page_size = 16
    total_slots = page_size

    k = torch.randn(
        num_tokens, num_kv_heads, head_dim, device=dev, dtype=torch.bfloat16
    )
    v = torch.randn_like(k)

    cache_shape = (total_slots, num_kv_heads, head_dim)
    k_cache = torch.empty(*cache_shape, device=dev, dtype=torch.float8_e4m3fn)
    v_cache = torch.empty_like(k_cache)

    cache_loc = torch.arange(num_tokens, device=dev, dtype=torch.int32)

    # 0-D tensor scales (as with nn.Parameter) mirror the production setup.
    k_scale = torch.tensor(1.0, device=dev, dtype=torch.float32)
    v_scale = torch.tensor(1.0, device=dev, dtype=torch.float32)

    # Capture the kernel launch into a CUDA graph; the old implementation
    # failed at this point with cudaErrorStreamCaptureUnsupported.
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        fused_fp8_set_kv_buffer(
            k,
            v,
            k_cache,
            v_cache,
            cache_loc,
            k_scale=k_scale,
            v_scale=v_scale,
            page_size=page_size,
            use_triton=True,
        )

    # Replaying the captured graph must also succeed.
    graph.replay()

def test_fp8_kv_kernel_cuda_graph_compatible_no_scale(self):
    """
    Regression test for CUDA graph capture without provided scales.

    When k_scale/v_scale are None (use_provided_scale=False), the wrapper
    must not allocate new GPU tensors during capture. Previously it
    created torch.tensor(1.0, device=...) inside the captured region and
    failed with cudaErrorStreamCaptureUnsupported; the fix passes dummy
    pointers, which the kernel never dereferences because it uses a
    constant inverse scale of 1.0 in that branch.
    """
    dev = torch.device("cuda")

    num_tokens, num_kv_heads, head_dim = 4, 2, 64
    page_size = 16
    total_slots = page_size

    key_states = torch.randn(
        num_tokens, num_kv_heads, head_dim, device=dev, dtype=torch.bfloat16
    )
    value_states = torch.randn_like(key_states)

    k_cache = torch.empty(
        total_slots,
        num_kv_heads,
        head_dim,
        device=dev,
        dtype=torch.float8_e4m3fn,
    )
    v_cache = torch.empty_like(k_cache)

    slot_ids = torch.arange(num_tokens, device=dev, dtype=torch.int32)

    # Capture WITHOUT scales: exercises the use_provided_scale=False branch.
    # The old implementation failed here with cudaErrorStreamCaptureUnsupported.
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        fused_fp8_set_kv_buffer(
            key_states,
            value_states,
            k_cache,
            v_cache,
            slot_ids,
            page_size=page_size,
            use_triton=True,
        )

    # Replay to confirm the captured graph is usable.
    graph.replay()


# Allow running this test module directly (outside a pytest/CI runner).
if __name__ == "__main__":
    unittest.main()
Loading