-
Notifications
You must be signed in to change notification settings - Fork 896
feat(gdn): add padding index guard for bf16 decode kernel #2810
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
b19cfec
b43e512
aaf9173
e756bf2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -23,7 +23,6 @@ | |
| import torch | ||
| import pytest | ||
|
|
||
| pytestmark = pytest.mark.skip(reason="Temporarily skipped due to CI failures.") | ||
|
|
||
| try: | ||
| from .reference_delta_rule import decode_delta_rule, verify_delta_rule | ||
|
|
@@ -203,6 +202,7 @@ def _test_decode_kernel_pretranspose( | |
| "num_q_heads, num_k_heads, num_v_heads", | ||
| [(16, 16, 32)], | ||
| ) | ||
| @pytest.mark.skip(reason="Temporarily skipped due to CI failures.") | ||
| @pytest.mark.parametrize("batch_size", [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]) | ||
| @pytest.mark.parametrize("dtype", ["bfloat16"]) | ||
| def test_decode_kernel_basic_pretranspose( | ||
|
|
@@ -368,6 +368,7 @@ def _test_decode_kernel_nontranspose( | |
| "num_q_heads, num_k_heads, num_v_heads", | ||
| [(16, 16, 32)], | ||
| ) | ||
| @pytest.mark.skip(reason="Temporarily skipped due to CI failures.") | ||
| @pytest.mark.parametrize("batch_size", [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]) | ||
| @pytest.mark.parametrize("dtype", ["bfloat16"]) | ||
| def test_decode_kernel_basic_nontranspose( | ||
|
|
@@ -512,6 +513,7 @@ def _test_decode_kernel_pretranspose_pool( | |
| @pytest.mark.parametrize("scale", [1.0]) | ||
| @pytest.mark.parametrize("head_size", [128]) | ||
| @pytest.mark.parametrize("num_q_heads, num_k_heads, num_v_heads", [(16, 16, 32)]) | ||
| @pytest.mark.skip(reason="Temporarily skipped due to CI failures.") | ||
| @pytest.mark.parametrize("batch_size", [1, 4, 16, 32]) | ||
| @pytest.mark.parametrize("dtype", ["bfloat16"]) | ||
| def test_decode_kernel_pretranspose_pool( | ||
|
|
@@ -768,6 +770,7 @@ def _test_decode_kernel_pretranspose_pool_all_padding( | |
| @pytest.mark.parametrize("scale", [1.0]) | ||
| @pytest.mark.parametrize("head_size", [128]) | ||
| @pytest.mark.parametrize("num_q_heads, num_k_heads, num_v_heads", [(16, 16, 32)]) | ||
| @pytest.mark.skip(reason="Temporarily skipped due to CI failures.") | ||
| @pytest.mark.parametrize("batch_size", [1, 4, 8, 32, 127]) | ||
| @pytest.mark.parametrize("dtype", ["bfloat16"]) | ||
| def test_decode_kernel_pretranspose_pool_negative_indices( | ||
|
|
@@ -795,6 +798,7 @@ def test_decode_kernel_pretranspose_pool_negative_indices( | |
| @pytest.mark.parametrize("scale", [1.0]) | ||
| @pytest.mark.parametrize("head_size", [128]) | ||
| @pytest.mark.parametrize("num_q_heads, num_k_heads, num_v_heads", [(16, 16, 32)]) | ||
| @pytest.mark.skip(reason="Temporarily skipped due to CI failures.") | ||
| @pytest.mark.parametrize("batch_size", [1, 4, 16, 32]) | ||
| @pytest.mark.parametrize("dtype", ["bfloat16"]) | ||
| def test_decode_kernel_pretranspose_pool_all_padding( | ||
|
|
@@ -819,6 +823,151 @@ def test_decode_kernel_pretranspose_pool_all_padding( | |
| ) | ||
|
|
||
|
|
||
| # ============================================================================ | ||
| # Test bf16 decode kernel with negative (padding) indices | ||
| # | ||
| # Verifies that the bf16 fast-path kernel handles negative indices correctly | ||
| # via the slot-0 null buffer pattern: negative indices are redirected to slot 0 | ||
| # inside the kernel. Valid slots must produce correct output and updated state; | ||
| # the kernel must not crash. | ||
| # ============================================================================ | ||
|
|
||
|
|
||
| def _test_decode_kernel_bf16_padding_indices( | ||
| batch_size: int, | ||
| num_q_heads: int, | ||
| num_v_heads: int, | ||
| head_size: int, | ||
| scale: float, | ||
| padding_fraction: float = 0.5, | ||
| seed: int = 0, | ||
| ): | ||
| """bf16 kernel with mixed negative/valid indices must not crash and must | ||
| produce correct output and state updates for valid slots.""" | ||
| _skip_if_not_sm90_or_later() | ||
|
|
||
| random.seed(seed) | ||
| torch.random.manual_seed(seed) | ||
| torch.cuda.manual_seed(seed) | ||
|
|
||
| pool_size = batch_size * 2 + 1 # slot 0 = null buffer; real slots start at 1 | ||
| device = torch.device("cuda") | ||
|
|
||
| with device: | ||
| q = torch.randn(batch_size, 1, num_q_heads, head_size, dtype=torch.bfloat16) | ||
| k = torch.nn.functional.normalize( | ||
| torch.randn(batch_size, 1, num_q_heads, head_size, dtype=torch.bfloat16), | ||
| p=2.0, | ||
| dim=-1, | ||
| ) | ||
| v = torch.randn(batch_size, 1, num_v_heads, head_size, dtype=torch.bfloat16) | ||
|
|
||
| A_log = torch.randn(num_v_heads, dtype=torch.float32) * 0.1 | ||
| dt_bias = torch.randn(num_v_heads, dtype=torch.float32) * 0.1 | ||
| a = torch.randn(batch_size, 1, num_v_heads, dtype=torch.bfloat16) * 0.1 | ||
| b = torch.randn(batch_size, 1, num_v_heads, dtype=torch.bfloat16) | ||
|
|
||
| # Slot 0 = null buffer (zeros); real slots start from 1 | ||
| pool = torch.zeros( | ||
| pool_size, num_v_heads, head_size, head_size, dtype=torch.bfloat16 | ||
| ) | ||
| pool[1:] = torch.randn( | ||
| pool_size - 1, num_v_heads, head_size, head_size, dtype=torch.bfloat16 | ||
| ) | ||
|
|
||
| # Build indices: some slots are padding (-1), others map to real slots [1, pool_size) | ||
| indices = torch.arange(1, batch_size + 1, dtype=torch.int32, device=device) | ||
| mask = torch.rand(batch_size, device=device) < padding_fraction | ||
| if batch_size >= 2: | ||
| mask[0] = False # ensure at least one valid | ||
| mask[-1] = True # ensure at least one padding | ||
| indices[mask] = -1 | ||
|
|
||
| valid_mask = indices >= 0 | ||
|
|
||
| # ── Pool path under test ───────────────────────────────────────────────── | ||
| pool_under_test = pool.clone() | ||
| out_pool, _ = gated_delta_rule_decode_pretranspose( | ||
| q=q, | ||
| k=k, | ||
| v=v, | ||
| state=None, | ||
| A_log=A_log, | ||
| a=a, | ||
| dt_bias=dt_bias, | ||
| b=b, | ||
| scale=scale, | ||
| use_qk_l2norm=True, | ||
| initial_state=pool_under_test, | ||
| initial_state_indices=indices, | ||
| ) | ||
| torch.cuda.synchronize() | ||
|
|
||
| # ── Direct-state reference for valid slots only ────────────────────────── | ||
| if valid_mask.any(): | ||
| valid_indices = indices[valid_mask] | ||
| gathered = pool[valid_indices].clone() | ||
| out_direct, updated = gated_delta_rule_decode_pretranspose( | ||
| q=q[valid_mask], | ||
| k=k[valid_mask], | ||
| v=v[valid_mask], | ||
| state=gathered, | ||
| A_log=A_log, | ||
| a=a[valid_mask], | ||
| dt_bias=dt_bias, | ||
| b=b[valid_mask], | ||
| scale=scale, | ||
| use_qk_l2norm=True, | ||
| ) | ||
| atol, rtol = 5e-3, 5e-3 | ||
| torch.testing.assert_close( | ||
| out_pool[valid_mask], out_direct, atol=atol, rtol=rtol | ||
| ) | ||
| torch.testing.assert_close( | ||
| pool_under_test[valid_indices], updated, atol=atol, rtol=rtol | ||
| ) | ||
|
|
||
| # Non-selected real slots (slots 1..pool_size-1 not in valid_indices) must be untouched | ||
| used = indices[valid_mask].to(device) | ||
| unused_mask = torch.ones(pool_size, dtype=torch.bool, device=device) | ||
| unused_mask[used] = False | ||
| unused_mask[0] = False # slot 0 may be modified (null buffer), don't check it | ||
|
Comment on lines
+931
to
+934
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. π§© Analysis chainπ Script executed: #!/bin/bash
# Verify whether int32 tensor indexing is accepted in the current torch runtime.
python - <<'PY'
import torch
dev = "cuda" if torch.cuda.is_available() else "cpu"
mask = torch.ones(8, dtype=torch.bool, device=dev)
idx_i32 = torch.tensor([1, 3], dtype=torch.int32, device=dev)
try:
mask[idx_i32] = False
print("int32_indexing_supported")
except Exception as e:
print("int32_indexing_not_supported")
print(type(e).__name__, str(e))
PYRepository: flashinfer-ai/flashinfer Length of output: 185 π Script executed: # First, locate and read the test file
fd test_decode_delta_rule.py -x wc -l {}Repository: flashinfer-ai/flashinfer Length of output: 110 π Script executed: # Read the specific lines and surrounding context from the test file
fd test_decode_delta_rule.py -x head -n 950 {} | tail -n 50Repository: flashinfer-ai/flashinfer Length of output: 2052 π Web query:
π‘ Result: PyTorch supports boolean mask indexing with tensors of dtype torch.bool. An int32 tensor (torch.int32) is not supported as a boolean mask for indexing; it must be explicitly converted to bool first (e.g., mask.bool). Citations:
π Web query:
π‘ Result: PyTorch tensor indexing with tensor indices (advanced indexing) requires the index tensors to have integer dtype, specifically torch.int64 (torch.long) for reliable behavior. While some smaller integer dtypes like int32, int16, int8 may work in basic cases, they can lead to inconsistent or unexpected results, especially with uint8 which historically was treated as boolean mask. Official documentation for functions like torch.index_select specifies "index (IntTensor or LongTensor)", where IntTensor is torch.int32 and LongTensor is torch.int64, but long is the standard and recommended dtype. The tensor.getitem follows similar conventions internally. Boolean tensors are used for boolean indexing/masking, not integer indexing. For scalar index tensors, there are known issues with uint8 and sometimes smaller signed ints producing wrong shapes. Key evidence from sources: - PyTorch forums discuss why indices must be long dtype [1]. - GitHub issues show uint8/int8/etc. inconsistencies and bugs in indexing [2,3]. - torch.index_select docs: "index (IntTensor or LongTensor)" [17]. - torch.index_add_ docs: index should have dtype torch.int64 or torch.int32 [22]. - tensor_attributes lists integer dtypes, but indexing ops prefer 32/64-bit [4]. Use torch.long indices: idx = torch.tensor([1,2,3], dtype=torch.long); x[idx] Citations:
π Script executed: # Find where 'indices' variable is defined in the test file
rg -n "indices\s*=" tests/gdn/test_decode_delta_rule.py | head -20Repository: flashinfer-ai/flashinfer Length of output: 919 π Script executed: # Look for the indices creation with dtype info
rg -B 5 -A 5 "indices.*int32\|torch\.int32.*indices\|dtype.*int32" tests/gdn/test_decode_delta_rule.pyRepository: flashinfer-ai/flashinfer Length of output: 50 π Script executed: # Find the function name containing lines 931-934
sed -n '880,935p' tests/gdn/test_decode_delta_rule.py | head -60Repository: flashinfer-ai/flashinfer Length of output: 1983 π Script executed: # Check the full function signature and location
rg -B 20 "used = indices\[valid_mask\]\.to\(device\)" tests/gdn/test_decode_delta_rule.py | head -40Repository: flashinfer-ai/flashinfer Length of output: 761 Cast index tensor to Line 933 uses Proposed fix- used = indices[valid_mask].to(device)
+ used = indices[valid_mask].to(device=device, dtype=torch.long)
unused_mask = torch.ones(pool_size, dtype=torch.bool, device=device)
unused_mask[used] = Falseπ€ Prompt for AI Agents |
||
| torch.testing.assert_close( | ||
| pool_under_test[unused_mask], pool[unused_mask], atol=0.0, rtol=0.0 | ||
| ) | ||
|
|
||
| # Slot 0 (null buffer) must have been written by padding slots. | ||
| # Without the kernel-level fix, padding slots do an OOB write to gH[-1] | ||
| # (before the pool base) leaving slot 0 untouched — this assertion catches that. | ||
| if mask.any(): | ||
| assert not torch.equal(pool_under_test[0], pool[0]), ( | ||
| "Slot 0 (null buffer) should have been updated by padding slots; " | ||
| "if it is unchanged the kernel fix is missing" | ||
| ) | ||
|
|
||
| print( | ||
| f"β bf16 padding indices test passed " | ||
| f"(batch={batch_size}, valid={valid_mask.sum().item()}, padding={mask.sum().item()})" | ||
| ) | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("scale", [1.0]) | ||
| @pytest.mark.parametrize("head_size", [128]) | ||
| @pytest.mark.parametrize("num_q_heads, num_v_heads", [(16, 32)]) | ||
| @pytest.mark.parametrize("batch_size", [1, 4, 16, 32]) | ||
| def test_decode_kernel_bf16_padding_indices( | ||
| batch_size: int, | ||
| num_q_heads: int, | ||
| num_v_heads: int, | ||
| head_size: int, | ||
| scale: float, | ||
| seed: int = int(os.environ.get("SEED", "0")), | ||
| ): | ||
| _test_decode_kernel_bf16_padding_indices( | ||
| batch_size, num_q_heads, num_v_heads, head_size, scale, seed=seed | ||
| ) | ||
|
|
||
|
|
||
| # ============================================================================ | ||
| # Test verify kernel with MTP version (Multiple Token Processing) | ||
| # Reference: fp32 h state (default). | ||
|
|
@@ -1016,6 +1165,7 @@ def _test_verify_kernel_mtp( | |
| "num_q_heads, num_k_heads, num_v_heads", | ||
| [(16, 16, 32)], | ||
| ) | ||
| @pytest.mark.skip(reason="Temporarily skipped due to CI failures.") | ||
| @pytest.mark.parametrize("batch_size", [1, 2, 4, 8, 16]) | ||
| @pytest.mark.parametrize("dtype", ["bfloat16"]) | ||
| def test_verify_kernel_mtp( | ||
|
|
@@ -1056,6 +1206,7 @@ def test_verify_kernel_mtp( | |
|
|
||
|
|
||
| @pytest.mark.parametrize("seq_len", [2, 3, 4, 5, 6, 7, 8]) | ||
| @pytest.mark.skip(reason="Temporarily skipped due to CI failures.") | ||
| @pytest.mark.parametrize("batch_size", [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]) | ||
| @pytest.mark.parametrize("dtype", ["bfloat16"]) | ||
| def test_mtp_fp32_state_with_cache_and_state_update( | ||
|
|
@@ -1268,6 +1419,7 @@ def _test_gdn_decode_klast_bf16_state_kernel( | |
| "num_q_heads, num_k_heads, num_v_heads", | ||
| [(16, 16, 32)], | ||
| ) | ||
| @pytest.mark.skip(reason="Temporarily skipped due to CI failures.") | ||
| @pytest.mark.parametrize("batch_size", [1, 2, 4, 8, 16, 32, 64, 128]) | ||
| @pytest.mark.parametrize("dtype", ["bfloat16"]) | ||
| def test_gdn_decode_klast_bf16_state_kernel( | ||
|
|
@@ -1299,6 +1451,7 @@ def test_gdn_decode_klast_bf16_state_kernel( | |
| ) | ||
|
|
||
|
|
||
| @pytest.mark.skip(reason="Temporarily skipped due to CI failures.") | ||
| @pytest.mark.parametrize("seq_len", [1, 2, 3, 4]) | ||
| @pytest.mark.parametrize("batch_size", [1, 2, 4]) | ||
| @pytest.mark.parametrize("head_size", [128]) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we still need this? cc @bkryu (seems it was first introduced in #2600).