-
Notifications
You must be signed in to change notification settings - Fork 1k
Added the cudnn backend Ragged KV Cache wrapper #2352
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1396,6 +1396,17 @@ def testBatchPrefillWithRaggedKVCacheWrapper(args): | |
| if remove_cudnn: | ||
| backends.remove("cudnn") | ||
|
|
||
| if "cudnn-native" in backends: | ||
| remove_cudnn_native = False | ||
| if q_dtype in [torch.float8_e4m3fn, torch.float8_e5m2] or kv_dtype in [ | ||
| torch.float8_e4m3fn, | ||
| torch.float8_e5m2, | ||
| ]: | ||
| print("[INFO] CUDNN-native backend does not support FP8. Skipping.") | ||
| remove_cudnn_native = True | ||
| if remove_cudnn_native: | ||
| backends.remove("cudnn-native") | ||
|
Comment on lines
+1399
to
+1408
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. π§© Analysis chainπ Script executed: # First, let's examine the file structure and imports
head -50 benchmarks/routines/attention.pyRepository: flashinfer-ai/flashinfer Length of output: 1213 π Script executed: # Search for CUDNN_AVAILABLE or similar constants in the file
rg "CUDNN_AVAILABLE|cudnn_available|is_cuda" benchmarks/routines/attention.pyRepository: flashinfer-ai/flashinfer Length of output: 1207 π Script executed: # Search for CUDNN_AVAILABLE across the entire codebase
rg "CUDNN_AVAILABLE" --type pyRepository: flashinfer-ai/flashinfer Length of output: 1417 π Script executed: # Check for torch.backends.cudnn availability checks
rg "torch\.backends\.cudnn|CUDNN|cudnn" benchmarks/routines/attention.py | head -30Repository: flashinfer-ai/flashinfer Length of output: 1323 π Script executed: # Let's look at the complete context around lines 1399-1408
sed -n '1380,1420p' benchmarks/routines/attention.pyRepository: flashinfer-ai/flashinfer Length of output: 1633 π Script executed: # Get exact line numbers where CUDNN_AVAILABLE check appears in relation to cudnn-native
rg -n "cudnn-native|CUDNN_AVAILABLE" benchmarks/routines/attention.py | grep -A 10 -B 5 "cudnn-native"Repository: flashinfer-ai/flashinfer Length of output: 829 π Script executed: # Get the complete cudnn and cudnn-native backend sections with line numbers
sed -n '1370,1450p' benchmarks/routines/attention.py | cat -n | sed 's/^[[:space:]]*//' | sed 's/^/1370+/'Repository: flashinfer-ai/flashinfer Length of output: 3843 π Script executed: # Check context around the first cudnn-native block (lines 784-798)
sed -n '770,810p' benchmarks/routines/attention.pyRepository: flashinfer-ai/flashinfer Length of output: 1801 π Script executed: # Check context around the second cudnn-native block (lines 1399-1408)
sed -n '1390,1415p' benchmarks/routines/attention.py
Repository: flashinfer-ai/flashinfer Length of output: 1030
Add cuDNN availability guard to cudnn-native backend filter. The decode routine (lines 1399-1408) checks for FP8 dtype compatibility but lacks the cuDNN availability check (`CUDNN_AVAILABLE` / `CUDNN_BACKEND_VERSION`) that the suggested fix below adds.
Suggested fix:
if "cudnn-native" in backends:
remove_cudnn_native = False
+ if not CUDNN_AVAILABLE or CUDNN_BACKEND_VERSION < 91701:
+ print("[INFO] cuDNN not available. Skipping cudnn-native backend.")
+ remove_cudnn_native = True
- if q_dtype in [torch.float8_e4m3fn, torch.float8_e5m2] or kv_dtype in [
+ elif q_dtype in [torch.float8_e4m3fn, torch.float8_e5m2] or kv_dtype in [
torch.float8_e4m3fn,
torch.float8_e5m2,
]:
print("[INFO] CUDNN-native backend does not support FP8. Skipping.")
remove_cudnn_native = True
🤖 Prompt for AI Agents |
||
|
|
||
| if "cutlass" in backends: | ||
| remove_cutlass = False | ||
| if q_dtype in [torch.float8_e4m3fn, torch.float8_e5m2] or kv_dtype in [ | ||
|
|
@@ -1609,6 +1620,34 @@ def testBatchPrefillWithRaggedKVCacheWrapper(args): | |
| q_data_type=q_dtype, | ||
| kv_data_type=kv_dtype, | ||
| ) | ||
| elif backend == "cudnn": | ||
| # cuDNN uses NHD layout and the wrapper API | ||
| backend_wrappers[backend] = ( | ||
| flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper( | ||
| workspace_buffer, | ||
| "NHD", | ||
| backend="cudnn", | ||
| ) | ||
| ) | ||
| backend_wrappers[backend].plan( | ||
| qo_indptr=q_indptr, | ||
| kv_indptr=k_indptr, | ||
| num_qo_heads=num_qo_heads, | ||
| num_kv_heads=num_kv_heads, | ||
| head_dim_qk=head_dim_qk, | ||
| head_dim_vo=head_dim_vo, | ||
| causal=causal, | ||
| sm_scale=scale, | ||
| q_data_type=q_dtype, | ||
| kv_data_type=kv_dtype, | ||
| o_data_type=q_dtype, | ||
| seq_lens=actual_seq_lens_kv_device, | ||
| seq_lens_q=actual_seq_lens_q_device, | ||
| max_token_per_sequence=s_qo, | ||
| max_sequence_kv=s_kv, | ||
| v_indptr=v_indptr, | ||
| o_indptr=o_indptr, | ||
| ) | ||
|
|
||
| k_scale, v_scale = None, None | ||
| if q_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]: | ||
|
|
@@ -1639,6 +1678,10 @@ def run_backend_wrapper( | |
| if backend in ["cutlass", "fa2", "fa3", "trtllm-gen"]: | ||
| return backend_wrappers[backend].run_return_lse(q, k, v)[0] | ||
| elif backend == "cudnn": | ||
| # cuDNN uses wrapper API | ||
| return backend_wrappers[backend].run(q, k, v) | ||
| elif backend == "cudnn-native": | ||
| # Direct cudnn_batch_prefill_with_kv_cache call | ||
| return flashinfer.prefill.cudnn_batch_prefill_with_kv_cache( | ||
| q, | ||
| k, | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2478,7 +2478,7 @@ def __init__( | |
| will be used in attention computation. | ||
|
|
||
| backend : str | ||
| The implementation backend, could be ``auto``/``fa2``/``fa3`` or ``cutlass``. | ||
| The implementation backend, could be ``auto``/``fa2``/``fa3``/``cudnn`` or ``cutlass``. | ||
| Defaults to ``auto``. | ||
| If set to ``auto``, the wrapper will automatically choose the backend based on the | ||
| device architecture and kernel availability. | ||
|
|
@@ -2598,6 +2598,12 @@ def plan( | |
| max_item_len_ptr: Optional[torch.Tensor] = None, | ||
| fixed_split_size: Optional[int] = None, | ||
| disable_split_kv: bool = False, | ||
| seq_lens: Optional[torch.Tensor] = None, | ||
|
Anerudhan marked this conversation as resolved.
|
||
| seq_lens_q: Optional[torch.Tensor] = None, | ||
| max_token_per_sequence: Optional[int] = None, | ||
| max_sequence_kv: Optional[int] = None, | ||
| v_indptr: Optional[torch.Tensor] = None, | ||
| o_indptr: Optional[torch.Tensor] = None, | ||
|
Comment on lines
+2601
to
+2606
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. π§© Analysis chainπ Script executed: git ls-files flashinfer/prefill.pyRepository: flashinfer-ai/flashinfer Length of output: 89 π Script executed: # Read lines around 2601-2606 and surrounding context
sed -n '2590,2620p' flashinfer/prefill.pyRepository: flashinfer-ai/flashinfer Length of output: 1480 π Script executed: # Read lines around 2824-2828 and surrounding context
sed -n '2810,2840p' flashinfer/prefill.pyRepository: flashinfer-ai/flashinfer Length of output: 1268 π Script executed: # Search for where seq_lens_q is first used in function definitions with docstrings
rg -A 20 "seq_lens_q" flashinfer/prefill.py | head -100Repository: flashinfer-ai/flashinfer Length of output: 3474 π Script executed: # Let's look at the function containing lines 2601-2606
sed -n '2570,2640p' flashinfer/prefill.py | cat -nRepository: flashinfer-ai/flashinfer Length of output: 3802 π Script executed: #!/bin/bash
# Find all occurrences of seq_lens_q assignment and read context
rg -B 5 -A 5 "self\._seq_lens_q\s*=" flashinfer/prefill.pyRepository: flashinfer-ai/flashinfer Length of output: 2286 π Script executed: # Search for seq_lens_q in docstrings or comments
rg "seq_lens_q" flashinfer/prefill.py -B 2 -A 2 | head -150Repository: flashinfer-ai/flashinfer Length of output: 5119 π Script executed: #!/bin/bash
# Get line count first
total_lines=$(wc -l < flashinfer/prefill.py)
echo "Total lines: $total_lines"
# Read the section around 2601-2606
echo "=== Lines 2590-2620 ==="
sed -n '2590,2620p' flashinfer/prefill.py
echo ""
echo "=== Lines 2810-2840 ==="
sed -n '2810,2840p' flashinfer/prefill.pyRepository: flashinfer-ai/flashinfer Length of output: 2748 π Script executed: #!/bin/bash
# Find the method definitions around the problem areas
echo "=== Around line 2600 (first occurrence) ==="
sed -n '2565,2580p' flashinfer/prefill.py
echo ""
echo "=== Around line 2800 (second occurrence) ==="
sed -n '2790,2810p' flashinfer/prefill.pyRepository: flashinfer-ai/flashinfer Length of output: 1607 π Script executed: #!/bin/bash
# Find the seq_lens_q documentation
echo "=== Docstring for seq_lens_q around line 2671 ==="
sed -n '2665,2680p' flashinfer/prefill.pyRepository: flashinfer-ai/flashinfer Length of output: 1194 π Script executed: #!/bin/bash
# Search for class definitions in the file
rg "^class " flashinfer/prefill.pyRepository: flashinfer-ai/flashinfer Length of output: 154 π Script executed: #!/bin/bash
# Count occurrences of the problematic pattern
echo "=== Pattern 1: self._seq_lens_q = seq_lens_q (without fallback) ==="
rg -n "self\._seq_lens_q = seq_lens_q$" flashinfer/prefill.py
echo ""
echo "=== Pattern 2: self._seq_lens_q = seq_lens_q if (with fallback) ==="
rg -n "self\._seq_lens_q = seq_lens_q if" flashinfer/prefill.pyRepository: flashinfer-ai/flashinfer Length of output: 330 π Script executed: #!/bin/bash
# Find the class definitions and their line numbers
echo "=== Class definitions with line numbers ==="
rg -n "^class " flashinfer/prefill.py
echo ""
echo "=== Location of line 1988 assignment ==="
sed -n '1980,1995p' flashinfer/prefill.py | cat -n
echo ""
echo "=== Location of line 2824 assignment ==="
sed -n '2820,2830p' flashinfer/prefill.py | cat -nRepository: flashinfer-ai/flashinfer Length of output: 1522 π Script executed: #!/bin/bash
# Search for seq_lens_q in docstrings specifically mentioning the default
rg -B 2 -A 2 "seq_lens_q.*will be set to" flashinfer/prefill.pyRepository: flashinfer-ai/flashinfer Length of output: 50 π Script executed: #!/bin/bash
# Read the full docstring section describing seq_lens_q
sed -n '2620,2720p' flashinfer/prefill.pyRepository: flashinfer-ai/flashinfer Length of output: 7077 π Script executed: #!/bin/bash
# Find which method line 1988 belongs to
echo "=== Method containing line 1988 ==="
sed -n '1900,1990p' flashinfer/prefill.py | grep -E "(def |class )" | tail -1
echo ""
echo "=== Method containing line 2824 ==="
sed -n '2550,2824p' flashinfer/prefill.py | grep -E "(def |class )" | tail -1
echo ""
echo "=== Verify line 2824 is in plan method ==="
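sed -n '2570,2580p' flashinfer/prefill.py
Repository: flashinfer-ai/flashinfer Length of output: 446
Set `self._seq_lens_q` to fall back to `seq_lens` when `seq_lens_q` is not provided. The docstring states that `seq_lens_q`, if not provided, "will be set to the same value as ``seq_lens``", but `plan()` stores `seq_lens_q` unchanged, so it can remain `None`.
Suggested fix:
- self._seq_lens_q = seq_lens_q
 self._seq_lens_kv = seq_lens
+ self._seq_lens_q = seq_lens_q if seq_lens_q is not None else seq_lens
🤖 Prompt for AI Agents |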
||
| ) -> None: | ||
| r"""Plan batch prefill/append attention on Ragged KV-Cache for given problem specification. | ||
|
|
||
|
|
@@ -2692,6 +2698,19 @@ def plan( | |
| and lead to a varied number of launched CTAs. | ||
| disable_split_kv : bool, | ||
| Whether to disable the split-kv for determinism in CUDA Graph, defaults to ``False``. | ||
| seq_lens: Optional[torch.Tensor] | ||
| A uint32 1D tensor indicating the kv sequence length of each prompt. shape: ``[batch_size]``. | ||
| seq_lens_q: Optional[torch.Tensor] | ||
| A uint32 1D tensor indicating the q sequence length of each prompt. shape: ``[batch_size]``. | ||
| If not provided, will be set to the same value as ``seq_lens``. | ||
| max_token_per_sequence: Optional[int], | ||
| Required for cudnn backend. This is the scalar max token length of each sequence. | ||
| max_sequence_kv: Optional[int], | ||
| Required for cudnn backend. This is the scalar max sequence length of each sequence in kv cache. | ||
| v_indptr: Optional[torch.Tensor] | ||
| Required for cudnn backend. This is the indptr of the value tensor. | ||
| o_indptr: Optional[torch.Tensor] | ||
| Required for cudnn backend. This is the indptr of the output tensor. | ||
| Note | ||
| ---- | ||
| The :meth:`plan` method should be called before any :meth:`run` or | ||
|
|
@@ -2781,6 +2800,17 @@ def plan( | |
| self.device, non_blocking=non_blocking | ||
| ) | ||
|
|
||
| self._o_indptr_buf = ( | ||
| o_indptr.to(self.device, non_blocking=non_blocking) | ||
| if o_indptr is not None | ||
| else self._qo_indptr_buf | ||
| ) | ||
| self._v_indptr_buf = ( | ||
| v_indptr.to(self.device, non_blocking=non_blocking) | ||
| if v_indptr is not None | ||
| else self._kv_indptr_buf | ||
| ) | ||
|
|
||
| self._cached_q_data_type = q_data_type | ||
| self._cached_kv_data_type = kv_data_type | ||
| self._cached_o_data_type = o_data_type | ||
|
|
@@ -2791,6 +2821,11 @@ def plan( | |
| self._token_pos_in_items_len = token_pos_in_items_len | ||
| self._max_item_len_ptr = max_item_len_ptr | ||
|
|
||
| self._seq_lens_q = seq_lens_q | ||
| self._seq_lens_kv = seq_lens | ||
| self._max_token_per_sequence = max_token_per_sequence | ||
| self._max_sequence_kv = max_sequence_kv | ||
|
|
||
| if self._jit_module is not None: | ||
| self._cached_module = self._jit_module | ||
| else: | ||
|
|
@@ -2822,7 +2857,7 @@ def plan( | |
| get_module_args[:9] + (qo_indptr.device,) + get_module_args[9:] | ||
| ) | ||
| self._cached_module = get_fmha_module(*new_get_module_args) | ||
| else: | ||
| elif self._backend != "cudnn": | ||
| self._cached_module = get_batch_prefill_module( | ||
| self._backend, *get_module_args | ||
| ) | ||
|
|
@@ -2832,7 +2867,7 @@ def plan( | |
| self._cached_module, qo_indptr, kv_indptr, num_qo_heads, causal | ||
| ) | ||
| self._max_qo_len = torch.max(qo_indptr[1:] - qo_indptr[:-1]).item() | ||
| else: | ||
| elif self._backend != "cudnn": | ||
| assert self._cached_module is not None, "cached module is not initialized" | ||
| args = [ | ||
| self._float_workspace_buffer, | ||
|
|
@@ -3040,6 +3075,40 @@ def run( | |
| lse=lse, | ||
| ) | ||
| return (out, lse) if return_lse else out | ||
| elif self._backend == "cudnn": | ||
| if self._seq_lens_q.dim() == 1: | ||
| batch_size = self._seq_lens_q.shape[0] | ||
| if self._seq_lens_q is not None and self._seq_lens_q.dim() == 1: | ||
| self._seq_lens_q = self._seq_lens_q.reshape(batch_size, 1, 1, 1) | ||
|
|
||
| if self._seq_lens_kv is not None and self._seq_lens_kv.dim() == 1: | ||
| self._seq_lens_kv = self._seq_lens_kv.reshape(batch_size, 1, 1, 1) | ||
|
Anerudhan marked this conversation as resolved.
|
||
|
|
||
| cudnn_batch_prefill_with_kv_cache( | ||
| q, | ||
| k, | ||
| v, | ||
| sm_scale, | ||
| self._float_workspace_buffer, | ||
| max_token_per_sequence=self._max_token_per_sequence, | ||
| max_sequence_kv=self._max_sequence_kv, | ||
| actual_seq_lens_q=self._seq_lens_q, | ||
| actual_seq_lens_kv=self._seq_lens_kv, | ||
| return_lse=return_lse, | ||
| causal=self._causal, | ||
| q_scale=q_scale, | ||
| k_scale=k_scale, | ||
| v_scale=v_scale, | ||
| batch_offsets_q=self._qo_indptr_buf, | ||
| batch_offsets_k=self._kv_indptr_buf, | ||
| batch_offsets_v=self._v_indptr_buf, | ||
| batch_offsets_o=self._o_indptr_buf, | ||
| is_cuda_graph_compatible=self._use_cuda_graph, | ||
| out=out, | ||
| lse=lse, | ||
| ) | ||
|
|
||
| return (out, lse) if return_lse else out | ||
|
Anerudhan marked this conversation as resolved.
Anerudhan marked this conversation as resolved.
|
||
|
|
||
| # Skip FP8->FP16 conversion for FA3 backend with FP8 support | ||
| # The JIT module will handle FP8 natively | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||||
|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -5,22 +5,23 @@ | |||||||||
|
|
||||||||||
|
|
||||||||||
| @pytest.mark.parametrize("batch_size", [1, 4]) | ||||||||||
| @pytest.mark.parametrize("s_qo", [32, 64, 87]) | ||||||||||
| @pytest.mark.parametrize("s_kv", [32, 64, 87]) | ||||||||||
| @pytest.mark.parametrize("num_kv_heads", [1]) | ||||||||||
| @pytest.mark.parametrize("num_qo_heads", [1, 16]) | ||||||||||
| @pytest.mark.parametrize("s_qo", [32, 64, 87, 256]) | ||||||||||
| @pytest.mark.parametrize("s_kv", [32, 87, 512]) | ||||||||||
| @pytest.mark.parametrize("num_kv_heads", [1, 4]) | ||||||||||
| @pytest.mark.parametrize("num_qo_heads", [1, 8]) | ||||||||||
| @pytest.mark.parametrize("causal", [True, False]) | ||||||||||
| def test_cudnn_prefill_deepseek( | ||||||||||
| batch_size, s_qo, s_kv, num_kv_heads, num_qo_heads, causal | ||||||||||
| ): | ||||||||||
| if s_qo > s_kv: | ||||||||||
| pytest.skip("s_qo > s_kv, skipping test as causal") | ||||||||||
|
|
||||||||||
| if num_qo_heads < num_kv_heads: | ||||||||||
| pytest.skip("num_qo_heads < num_kv_heads, skipping test") | ||||||||||
|
Comment on lines
7
to
+20
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. π§© Analysis chainπ Script executed: # First, find the test file and read it
cat -n tests/attention/test_cudnn_prefill_deepseek.pyRepository: flashinfer-ai/flashinfer Length of output: 5877 π Script executed: # Check what's available in flashinfer.utils
find . -name "utils.py" -o -name "utils" -type d | grep flashinfer | head -20Repository: flashinfer-ai/flashinfer Length of output: 301 π Script executed: # Look for get_compute_capability and related functions in flashinfer
rg "get_compute_capability|is_sm90a_supported|is_sm100a_supported" --type pyRepository: flashinfer-ai/flashinfer Length of output: 17993 π Script executed: cat -n tests/attention/test_cudnn_prefill.py | head -50Repository: flashinfer-ai/flashinfer Length of output: 1740 π Script executed: # Check flashinfer.utils for cuDNN-specific checks
rg "cudnn|get_compute_capability" flashinfer/utils.py -A 3 -B 1Repository: flashinfer-ai/flashinfer Length of output: 2283 π Script executed: # Check if there's documentation or comments about cuDNN GPU requirements
rg "cuDNN|cudnn.*sm|cudnn.*arch|cudnn.*compute" --type py -i -B 2 -A 2 | head -80Repository: flashinfer-ai/flashinfer Length of output: 4924 π Script executed: # Check conftest to see if there are global skip mechanisms for cuDNN
cat -n tests/conftest.py 2>/dev/null | head -100Repository: flashinfer-ai/flashinfer Length of output: 4344 π Script executed: # Continue reading conftest
tail -n +100 tests/conftest.py | head -80Repository: flashinfer-ai/flashinfer Length of output: 3055 π Script executed: # Look for OOM skip mechanisms mentioned in the coding guidelines
rg "OOM|out.of.memory|OutOfMemory" --type py tests/Repository: flashinfer-ai/flashinfer Length of output: 451 π Script executed: # Check if there are any SM requirements documented in flashinfer for cuDNN prefill
rg "cudnn.*prefill|prefill.*cudnn" --type py -B 3 -A 3Repository: flashinfer-ai/flashinfer Length of output: 6803 π Script executed: # Check what compute capabilities are actually needed by looking at BatchPrefillWithRaggedKVCacheWrapper
rg "BatchPrefillWithRaggedKVCacheWrapper" --type py -lRepository: flashinfer-ai/flashinfer Length of output: 836 π Script executed: # Check the actual prefill.py to see if there are SM requirements
grep -n "cudnn\|backend" flashinfer/prefill.py | head -30Repository: flashinfer-ai/flashinfer Length of output: 1608 π Script executed: # Read the full test_cudnn_prefill.py to see what GPU checks it has
cat -n tests/attention/test_cudnn_prefill.py | grep -A 30 "def test_cudnn_prefill"Repository: flashinfer-ai/flashinfer Length of output: 1982 π Script executed: # Check flashinfer/cudnn/__init__.py to understand cuDNN availability
cat -n flashinfer/cudnn/__init__.pyRepository: flashinfer-ai/flashinfer Length of output: 189 π Script executed: # Check if there's a CUDNN_AVAILABLE flag or similar
rg "CUDNN_AVAILABLE" --type py
Repository: flashinfer-ai/flashinfer Length of output: 1417
Add GPU capability and cuDNN availability checks to gate test execution. This test uses the cuDNN backend explicitly but lacks guards for GPU architecture support and cuDNN availability. As per coding guidelines, tests must skip on unsupported hardware. The 512MB workspace allocation can also cause OOM on smaller GPUs. Add checks before tensor allocations.
🧪 Suggested skip guards
import pytest
import torch
import flashinfer
+from flashinfer.utils import get_compute_capability
+
+try:
+ import cudnn # type: ignore
+ CUDNN_AVAILABLE = True
+except (ImportError, OSError):
+ CUDNN_AVAILABLE = False
@pytest.mark.parametrize("batch_size", [1, 4])
@@ -26,6 +34,14 @@ def test_cudnn_prefill_deepseek(
batch_size, s_qo, s_kv, num_kv_heads, num_qo_heads, causal
):
if s_qo > s_kv:
pytest.skip("s_qo > s_kv, skipping test as causal")
if num_qo_heads < num_kv_heads:
pytest.skip("num_qo_heads < num_kv_heads, skipping test")
+
+ if not CUDNN_AVAILABLE:
+ pytest.skip("cuDNN not available")
+ major, _ = get_compute_capability(torch.device("cuda:0"))
+ if major < 8:
+ pytest.skip("cuDNN prefill requires SM80+")
🤖 Prompt for AI Agents |
||||||||||
|
|
||||||||||
| head_dim_qk = 192 | ||||||||||
| head_dim_vo = 128 | ||||||||||
|
|
||||||||||
| return_lse = True | ||||||||||
|
|
||||||||||
| # test set up basics | ||||||||||
| seed = 0 | ||||||||||
| torch.manual_seed(seed) | ||||||||||
|
|
@@ -76,14 +77,14 @@ def test_cudnn_prefill_deepseek( | |||||||||
| ] | ||||||||||
| ).int() | ||||||||||
|
|
||||||||||
| batch_offsets_stats = torch.cat( | ||||||||||
| [ | ||||||||||
| torch.zeros( | ||||||||||
| 1, device=actual_seq_lens_q.device, dtype=actual_seq_lens_q.dtype | ||||||||||
| ), | ||||||||||
| torch.cumsum(actual_seq_lens_q.flatten(), dim=0) * num_qo_heads, | ||||||||||
| ] | ||||||||||
| ).cuda() | ||||||||||
| # batch_offsets_stats = torch.cat( | ||||||||||
| # [ | ||||||||||
| # torch.zeros( | ||||||||||
| # 1, device=actual_seq_lens_q.device, dtype=actual_seq_lens_q.dtype | ||||||||||
| # ), | ||||||||||
| # torch.cumsum(actual_seq_lens_q.flatten(), dim=0) * num_qo_heads, | ||||||||||
| # ] | ||||||||||
| # ).cuda() | ||||||||||
|
|
||||||||||
| k_cache = torch.randn( | ||||||||||
| batch_size * s_kv, | ||||||||||
|
|
@@ -103,45 +104,45 @@ def test_cudnn_prefill_deepseek( | |||||||||
| # Initialize scale | ||||||||||
| scale = float(1.0 / (head_dim_qk**0.5)) | ||||||||||
|
|
||||||||||
| workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8, device=device) | ||||||||||
| workspace_buffer = torch.empty(512 * 1024 * 1024, dtype=torch.int8, device=device) | ||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Right-size the workspace buffer to reduce OOM risk. Hardcoding a 512MB workspace can exhaust memory on smaller GPUs. Consider capping it relative to device memory. As per coding guidelines, avoid OOM-prone test sizes.
💡 Safer workspace sizing
- workspace_buffer = torch.empty(512 * 1024 * 1024, dtype=torch.int8, device=device)
+ total_mem = torch.cuda.get_device_properties(device).total_memory
+ workspace_bytes = min(512 * 1024 * 1024, total_mem // 8)
+ workspace_buffer = torch.empty(workspace_bytes, dtype=torch.int8, device=device)
📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||||||
|
|
||||||||||
| # output = torch.zeros_like(q) | ||||||||||
| output, lse = flashinfer.prefill.cudnn_batch_prefill_with_kv_cache( | ||||||||||
| q, | ||||||||||
| k_cache, | ||||||||||
| v_cache, | ||||||||||
| scale, | ||||||||||
| workspace_buffer, | ||||||||||
| wrapper_cudnn = flashinfer.BatchPrefillWithRaggedKVCacheWrapper( | ||||||||||
| workspace_buffer, "NHD", backend="cudnn" | ||||||||||
| ) | ||||||||||
|
|
||||||||||
| wrapper_cudnn.plan( | ||||||||||
| qo_indptr=q_indptr, | ||||||||||
| kv_indptr=k_indptr, | ||||||||||
| num_qo_heads=num_qo_heads, | ||||||||||
| num_kv_heads=num_kv_heads, | ||||||||||
| head_dim_qk=head_dim_qk, | ||||||||||
| head_dim_vo=head_dim_vo, | ||||||||||
| causal=causal, | ||||||||||
| sm_scale=scale, | ||||||||||
| q_data_type=torch.bfloat16, | ||||||||||
| kv_data_type=torch.bfloat16, | ||||||||||
| o_data_type=torch.bfloat16, | ||||||||||
| seq_lens=actual_seq_lens_kv, | ||||||||||
| seq_lens_q=actual_seq_lens_q, | ||||||||||
|
Anerudhan marked this conversation as resolved.
|
||||||||||
| max_token_per_sequence=s_qo, | ||||||||||
| max_sequence_kv=s_kv, | ||||||||||
| actual_seq_lens_q=actual_seq_lens_q, | ||||||||||
| actual_seq_lens_kv=actual_seq_lens_kv, | ||||||||||
| causal=causal, | ||||||||||
| return_lse=return_lse, | ||||||||||
| batch_offsets_q=q_indptr, | ||||||||||
| batch_offsets_k=k_indptr, | ||||||||||
| batch_offsets_v=v_indptr, | ||||||||||
| batch_offsets_o=o_indptr, | ||||||||||
| batch_offsets_stats=batch_offsets_stats, | ||||||||||
| is_cuda_graph_compatible=True, | ||||||||||
| v_indptr=v_indptr, | ||||||||||
| o_indptr=o_indptr, | ||||||||||
| ) | ||||||||||
|
Anerudhan marked this conversation as resolved.
|
||||||||||
|
|
||||||||||
| output = wrapper_cudnn.run(q, k_cache, v_cache) | ||||||||||
|
|
||||||||||
| qo_indptr = torch.cat( | ||||||||||
| [ | ||||||||||
| torch.tensor([0], device=device), | ||||||||||
| torch.cumsum(actual_seq_lens_q.view(-1), dim=0), | ||||||||||
| ] | ||||||||||
| ).int() | ||||||||||
|
|
||||||||||
| # kv_indptr = torch.arange(0, batch_size + 1, device="cuda", dtype=torch.int32) * s_kv | ||||||||||
|
|
||||||||||
| # Create kv_indptr as cumulative sum of actual_seq_lens_kv | ||||||||||
| kv_indptr = torch.cat( | ||||||||||
| [ | ||||||||||
| torch.tensor( | ||||||||||
| [0], | ||||||||||
| device=device, | ||||||||||
| ), | ||||||||||
| torch.tensor([0], device=device), | ||||||||||
| torch.cumsum(actual_seq_lens_kv.view(-1), dim=0), | ||||||||||
| ] | ||||||||||
| ).int() | ||||||||||
|
|
||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fix unordered list indentation to satisfy MD007.
Markdownlint flags inconsistent indentation on this new sub-bullet; align it with the other “Also supports” entries.
🧹 Markdownlint fix
📝 Committable suggestion
🧰 Tools
🪛 markdownlint-cli2 (0.18.1)
19-19: Unordered list indentation
Expected: 4; Actual: 8
(MD007, ul-indent)
🤖 Prompt for AI Agents