diff --git a/python/sglang/srt/layers/attention/trtllm_fp8_kv_kernel.py b/python/sglang/srt/layers/attention/trtllm_fp8_kv_kernel.py index e10b2f9bc684..41cc9c9dc367 100644 --- a/python/sglang/srt/layers/attention/trtllm_fp8_kv_kernel.py +++ b/python/sglang/srt/layers/attention/trtllm_fp8_kv_kernel.py @@ -94,9 +94,9 @@ def _fused_fp8_set_kv_buffer_kernel( v_cache_ptr, # [total_slots, num_kv_heads, head_dim] # Cache location indices cache_loc_ptr, # [num_tokens] -> token to cache location mapping - # Scalar scale (if provided, will be used; otherwise computed per-token) - k_scale, # scalar float - v_scale, # scalar float + # Pointers to scalar inverse scales (computed on GPU in wrapper) + inv_k_scale_ptr, # pointer to 0-D tensor on GPU + inv_v_scale_ptr, # pointer to 0-D tensor on GPU use_provided_scale: tl.constexpr, # whether to use provided scale # Tensor dimensions num_kv_heads: tl.constexpr, @@ -147,7 +147,10 @@ def _fused_fp8_set_kv_buffer_kernel( # Select K or V based on kv_idx if kv_idx == 0: # Process K tensor - inv_scale = 1.0 / k_scale if use_provided_scale else 1.0 + if use_provided_scale: + inv_scale = tl.load(inv_k_scale_ptr) + else: + inv_scale = 1.0 _process_kv_tensor( token_id, head_block_id, @@ -171,7 +174,10 @@ def _fused_fp8_set_kv_buffer_kernel( ) else: # Process V tensor - inv_scale = 1.0 / v_scale if use_provided_scale else 1.0 + if use_provided_scale: + inv_scale = tl.load(inv_v_scale_ptr) + else: + inv_scale = 1.0 _process_kv_tensor( token_id, head_block_id, @@ -343,6 +349,37 @@ def fused_fp8_set_kv_buffer( # - dim 2: K/V (0=K, 1=V) grid = (num_tokens, num_head_blocks, 2) + device = k_3d.device + + def _to_tensor_scale(scale): + """Convert scale to 0-D CUDA tensor (accepts Python float or Tensor).""" + if isinstance(scale, torch.Tensor): + return scale.to(device=device, dtype=torch.float32) + else: + # Python float / np scalar + return torch.tensor(float(scale), device=device, dtype=torch.float32) + + # Compute inverse scales on GPU to avoid 
GPU→CPU sync in CUDA graph capture. + # Previously we used float(k_scale) which triggers synchronization and fails + # during CUDA graph capture with cudaErrorStreamCaptureUnsupported. + if use_provided_scale: + k_scale_tensor = _to_tensor_scale(k_scale) + v_scale_tensor = _to_tensor_scale(v_scale) + + # Pure GPU scalar operation, safe for CUDA graph + inv_k_scale = (1.0 / k_scale_tensor).to(device=device, dtype=torch.float32) + inv_v_scale = (1.0 / v_scale_tensor).to(device=device, dtype=torch.float32) + + inv_k_scale_ptr = inv_k_scale + inv_v_scale_ptr = inv_v_scale + else: + # When use_provided_scale=False, kernel uses constant 1.0 for inv_scale. + # Triton will optimize away the tl.load() calls via constant folding. + # We pass dummy pointers (k_3d) which won't be accessed in the kernel. + # This avoids creating new GPU tensors during CUDA graph capture. + inv_k_scale_ptr = k_3d + inv_v_scale_ptr = k_3d + # Launch Triton kernel _fused_fp8_set_kv_buffer_kernel[grid]( k_3d, @@ -350,8 +387,8 @@ def fused_fp8_set_kv_buffer( k_cache, v_cache, cache_loc, - k_scale if k_scale is not None else 1.0, - v_scale if v_scale is not None else 1.0, + inv_k_scale_ptr, + inv_v_scale_ptr, use_provided_scale, num_kv_heads, head_dim, diff --git a/test/manual/test_trtllm_fp8_kv_kernel.py b/test/manual/test_trtllm_fp8_kv_kernel.py index e980ac221110..b713747a2a0c 100644 --- a/test/manual/test_trtllm_fp8_kv_kernel.py +++ b/test/manual/test_trtllm_fp8_kv_kernel.py @@ -301,6 +301,179 @@ def test_empty_input(self): page_size=page_size, ) + def test_fp8_kv_kernel_accepts_tensor_scales(self): + """ + Regression test for B200 Triton compilation issue. + + This test ensures that fused_fp8_set_kv_buffer correctly handles + k_scale/v_scale when they are 0-dimensional tensors (torch.nn.Parameter). + + Previously, Triton would treat 0-D tensor arguments as pointers, + causing a type error when performing "1.0 / k_scale" inside the kernel. 
+        The fix computes the inverse scales as 0-D GPU tensors in the wrapper
+        and passes them to the kernel as pointers loaded via tl.load().
+        """
+        device = torch.device("cuda")
+
+        num_tokens = 4
+        num_kv_heads = 2
+        head_dim = 64
+        page_size = 16
+        total_slots = page_size
+
+        k = torch.randn(
+            num_tokens, num_kv_heads, head_dim, device=device, dtype=torch.bfloat16
+        )
+        v = torch.randn_like(k)
+
+        k_cache = torch.empty(
+            total_slots,
+            num_kv_heads,
+            head_dim,
+            device=device,
+            dtype=torch.float8_e4m3fn,
+        )
+        v_cache = torch.empty_like(k_cache)
+
+        cache_loc = torch.arange(num_tokens, device=device, dtype=torch.int32)
+
+        # Use 0D tensor form of scale to reproduce the original bug scenario
+        k_scale = torch.tensor(1.0, device=device, dtype=torch.float32)
+        v_scale = torch.tensor(1.0, device=device, dtype=torch.float32)
+
+        # Old code would trigger Triton's IncompatibleTypeError here
+        # New code should handle this gracefully by passing tensor pointers
+        fused_fp8_set_kv_buffer(
+            k,
+            v,
+            k_cache,
+            v_cache,
+            cache_loc,
+            k_scale=k_scale,
+            v_scale=v_scale,
+            page_size=page_size,
+            use_triton=True,
+        )
+
+        # If we get here without exception, the regression is fixed
+
+    def test_fp8_kv_kernel_cuda_graph_compatible(self):
+        """
+        Regression test for CUDA graph capture compatibility.
+
+        This test ensures that fused_fp8_set_kv_buffer works correctly within
+        CUDA graph capture, which is used in production for performance.
+
+        Previously, float(k_scale) caused GPU→CPU synchronization, triggering
+        cudaErrorStreamCaptureUnsupported during graph capture. The fix computes
+        inverse scales purely on GPU using tensor operations. 
+ """ + device = torch.device("cuda") + + num_tokens = 4 + num_kv_heads = 2 + head_dim = 64 + page_size = 16 + total_slots = page_size + + k = torch.randn( + num_tokens, num_kv_heads, head_dim, device=device, dtype=torch.bfloat16 + ) + v = torch.randn_like(k) + + k_cache = torch.empty( + total_slots, + num_kv_heads, + head_dim, + device=device, + dtype=torch.float8_e4m3fn, + ) + v_cache = torch.empty_like(k_cache) + + cache_loc = torch.arange(num_tokens, device=device, dtype=torch.int32) + + # Use 0D tensor scales (like nn.Parameter) to reproduce production scenario + k_scale = torch.tensor(1.0, device=device, dtype=torch.float32) + v_scale = torch.tensor(1.0, device=device, dtype=torch.float32) + + # Test that kernel works under CUDA graph capture + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph): + # Old code would fail here with cudaErrorStreamCaptureUnsupported + # New code should succeed because all operations stay on GPU + fused_fp8_set_kv_buffer( + k, + v, + k_cache, + v_cache, + cache_loc, + k_scale=k_scale, + v_scale=v_scale, + page_size=page_size, + use_triton=True, + ) + + # Replay the graph to verify it works + graph.replay() + + # If we get here without exception, CUDA graph compatibility is confirmed + + def test_fp8_kv_kernel_cuda_graph_compatible_no_scale(self): + """ + Regression test for CUDA graph capture compatibility without scales. + + This test ensures that fused_fp8_set_kv_buffer works correctly within + CUDA graph capture when k_scale/v_scale are None (use_provided_scale=False). + + Previously, the code created new GPU tensors (torch.tensor(1.0, device=...)) + during graph capture, triggering cudaErrorStreamCaptureUnsupported. + The fix passes dummy pointers when use_provided_scale=False, as the kernel + uses constant 1.0 and Triton optimizes away the pointer loads. 
+ """ + device = torch.device("cuda") + + num_tokens = 4 + num_kv_heads = 2 + head_dim = 64 + page_size = 16 + total_slots = page_size + + k = torch.randn( + num_tokens, num_kv_heads, head_dim, device=device, dtype=torch.bfloat16 + ) + v = torch.randn_like(k) + + k_cache = torch.empty( + total_slots, + num_kv_heads, + head_dim, + device=device, + dtype=torch.float8_e4m3fn, + ) + v_cache = torch.empty_like(k_cache) + + cache_loc = torch.arange(num_tokens, device=device, dtype=torch.int32) + + # Test that kernel works under CUDA graph capture WITHOUT scales + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph): + # No k_scale/v_scale provided - use_provided_scale=False branch + # Old code would fail here with cudaErrorStreamCaptureUnsupported + # New code should succeed by using dummy pointers + fused_fp8_set_kv_buffer( + k, + v, + k_cache, + v_cache, + cache_loc, + page_size=page_size, + use_triton=True, + ) + + # Replay the graph to verify it works + graph.replay() + + # If we get here without exception, no-scale CUDA graph compatibility is confirmed + if __name__ == "__main__": unittest.main()