diff --git a/flashinfer/norm/kernels/fused_add_rmsnorm.py b/flashinfer/norm/kernels/fused_add_rmsnorm.py
index 31e3840d4a..be5917e614 100644
--- a/flashinfer/norm/kernels/fused_add_rmsnorm.py
+++ b/flashinfer/norm/kernels/fused_add_rmsnorm.py
@@ -902,8 +902,8 @@ def _get_compiled_fused_add_rmsnorm_kernel(
             dtype, (sym_m, H), stride_order=(1, 0), assumed_align=tensor_align
         )
     else:
-        sym_row_stride_x = cute.sym_int(divisibility=kernel_obj.vec_size)
-        sym_row_stride_r = cute.sym_int(divisibility=kernel_obj.vec_size)
+        sym_row_stride_x = cute.sym_int64(divisibility=kernel_obj.vec_size)
+        sym_row_stride_r = cute.sym_int64(divisibility=kernel_obj.vec_size)
         x_fake = cute.runtime.make_fake_tensor(
             dtype, (sym_m, H), (sym_row_stride_x, 1), assumed_align=16
         )
@@ -966,9 +966,9 @@ def _get_compiled_fused_add_rmsnorm_quant_kernel(
             dtype, (sym_m, H), stride_order=(1, 0), assumed_align=in_align
         )
     else:
-        sym_row_stride_y = cute.sym_int(divisibility=kernel_obj.vec_size)
-        sym_row_stride_x = cute.sym_int(divisibility=kernel_obj.vec_size)
-        sym_row_stride_r = cute.sym_int(divisibility=kernel_obj.vec_size)
+        sym_row_stride_y = cute.sym_int64(divisibility=kernel_obj.vec_size)
+        sym_row_stride_x = cute.sym_int64(divisibility=kernel_obj.vec_size)
+        sym_row_stride_r = cute.sym_int64(divisibility=kernel_obj.vec_size)
         y_fake = cute.runtime.make_fake_tensor(
             out_dtype, (sym_m, H), (sym_row_stride_y, 1), assumed_align=16
         )
diff --git a/flashinfer/norm/kernels/rmsnorm.py b/flashinfer/norm/kernels/rmsnorm.py
index a47a980036..4cf55611a0 100644
--- a/flashinfer/norm/kernels/rmsnorm.py
+++ b/flashinfer/norm/kernels/rmsnorm.py
@@ -1127,8 +1127,8 @@ def _get_compiled_rmsnorm_kernel(
             dtype, (sym_m, H), stride_order=(1, 0), assumed_align=tensor_align
         )
     else:
-        sym_row_stride_x = cute.sym_int(divisibility=kernel_obj.vec_size)
-        sym_row_stride_y = cute.sym_int(divisibility=kernel_obj.vec_size)
+        sym_row_stride_x = cute.sym_int64(divisibility=kernel_obj.vec_size)
+        sym_row_stride_y = cute.sym_int64(divisibility=kernel_obj.vec_size)
         x_fake = cute.runtime.make_fake_tensor(
             dtype, (sym_m, H), (sym_row_stride_x, 1), assumed_align=16
         )
@@ -1168,10 +1168,10 @@ def _get_compiled_qk_rmsnorm_kernel(
     # Stride divisibility = vec_size guarantees each row start is aligned
     # for the chosen copy_bits (e.g. vec_size=8 for fp16 → 16-byte aligned).
-    sym_batch_stride_x = cute.sym_int(divisibility=kernel_obj.vec_size)
-    sym_head_stride_x = cute.sym_int(divisibility=kernel_obj.vec_size)
-    sym_batch_stride_y = cute.sym_int(divisibility=kernel_obj.vec_size)
-    sym_head_stride_y = cute.sym_int(divisibility=kernel_obj.vec_size)
+    sym_batch_stride_x = cute.sym_int64(divisibility=kernel_obj.vec_size)
+    sym_head_stride_x = cute.sym_int64(divisibility=kernel_obj.vec_size)
+    sym_batch_stride_y = cute.sym_int64(divisibility=kernel_obj.vec_size)
+    sym_head_stride_y = cute.sym_int64(divisibility=kernel_obj.vec_size)
     x_fake = cute.runtime.make_fake_tensor(
         dtype,
@@ -1238,8 +1238,8 @@ def _get_compiled_rmsnorm_quant_kernel(
             out_dtype, (sym_m, H), stride_order=(1, 0), assumed_align=out_align
         )
     else:
-        sym_row_stride_x = cute.sym_int(divisibility=kernel_obj.vec_size)
-        sym_row_stride_y = cute.sym_int(divisibility=kernel_obj.vec_size)
+        sym_row_stride_x = cute.sym_int64(divisibility=kernel_obj.vec_size)
+        sym_row_stride_y = cute.sym_int64(divisibility=kernel_obj.vec_size)
         x_fake = cute.runtime.make_fake_tensor(
             dtype, (sym_m, H), (sym_row_stride_x, 1), assumed_align=16
         )
diff --git a/tests/utils/test_norm.py b/tests/utils/test_norm.py
index 1bda5005a2..a69a3ec018 100644
--- a/tests/utils/test_norm.py
+++ b/tests/utils/test_norm.py
@@ -347,6 +347,130 @@ def test_layernorm(batch_size, hidden_size, dtype):
     torch.testing.assert_close(out, out_ref, rtol=1e-2, atol=1e-2)


+# =============================================================================
+# Regression tests for int32 stride overflow
+# =============================================================================
+# These tests verify that rmsnorm kernels accept tensors with strides exceeding
+# INT32_MAX. The 2D tests use M=2 so that is_contiguous() returns False and the
+# non-contiguous kernel path (which uses sym_int64 strides) is exercised. This
+# requires a ~4 GiB flat buffer so that the large stride is actually traversable.
+# The 3D qknorm test can use batch=1 because qk_rmsnorm_cute always uses
+# symbolic strides regardless of contiguity.
+
+_INT64_STRIDE = 2**31  # just above INT32_MAX = 2**31 - 1
+_STRIDE_BUF_BYTES = (_INT64_STRIDE + 128) * 2  # bf16, H=128
+
+
+def _skip_if_low_vram():
+    free, _ = torch.cuda.mem_get_info()
+    if free < _STRIDE_BUF_BYTES * 1.2:
+        pytest.skip(
+            f"Requires ~{_STRIDE_BUF_BYTES / 1024**3:.1f}GiB free VRAM, "
+            f"only {free / 1024**3:.1f}GiB available"
+        )
+
+
+def test_rmsnorm_int64_stride():
+    """2D rmsnorm with row stride > INT32_MAX (issue #3005)."""
+    _skip_if_low_vram()
+    H = 128
+    dtype = torch.bfloat16
+    buf = torch.randn(_INT64_STRIDE + H, dtype=dtype, device="cuda")
+    w = torch.randn(H, dtype=dtype, device="cuda")
+
+    x = torch.as_strided(buf, (2, H), (_INT64_STRIDE, 1))
+    assert not x.is_contiguous()
+    y = flashinfer.norm.rmsnorm(x, w)
+
+    y_ref = llama_rms_norm(x.contiguous(), w)
+    torch.testing.assert_close(y, y_ref, rtol=1e-3, atol=1e-3)
+
+
+def test_qknorm_int64_stride():
+    """3D qk_rmsnorm with batch stride > INT32_MAX (issue #3005).
+
+    qk_rmsnorm_cute always uses symbolic strides (no contiguity check),
+    so batch=1 suffices: the large stride is validated by TVM-FFI even
+    though it is never traversed.
+ """ + num_heads, head_dim = 4, 128 + dtype = torch.bfloat16 + buf = torch.randn(1, num_heads, head_dim, dtype=dtype, device="cuda") + w = torch.randn(head_dim, dtype=dtype, device="cuda") + + x = torch.as_strided(buf, (1, num_heads, head_dim), (_INT64_STRIDE, head_dim, 1)) + y = flashinfer.norm.rmsnorm(x, w) + + y_ref = llama_rms_norm(buf, w) + torch.testing.assert_close(y, y_ref, rtol=1e-3, atol=1e-3) + + +def test_rmsnorm_quant_int64_stride(): + """rmsnorm_quant with row stride > INT32_MAX (issue #3005).""" + _skip_if_low_vram() + H = 128 + dtype = torch.bfloat16 + quant_scale = 1.0 + buf = torch.randn(_INT64_STRIDE + H, dtype=dtype, device="cuda") + w = torch.randn(H, dtype=dtype, device="cuda") + + x = torch.as_strided(buf, (2, H), (_INT64_STRIDE, 1)) + assert not x.is_contiguous() + y = torch.empty(2, H, dtype=torch.float8_e4m3fn, device="cuda") + flashinfer.norm.rmsnorm_quant(y, x, w, torch.tensor(quant_scale, device="cuda")) + + y_ref = llama_rms_norm_quant(x.contiguous(), w, quant_scale) + torch.testing.assert_close(y.float(), y_ref.float(), rtol=1, atol=1) + + +def test_fused_add_rmsnorm_int64_stride(): + """fused_add_rmsnorm with row stride > INT32_MAX (issue #3005).""" + _skip_if_low_vram() + H = 128 + dtype = torch.bfloat16 + eps = 1e-6 + buf_x = torch.randn(_INT64_STRIDE + H, dtype=dtype, device="cuda") + w = torch.randn(H, dtype=dtype, device="cuda") + # Contiguous residual — only one non-contiguous tensor is needed to + # trigger the non-contiguous kernel path. + r = torch.randn(2, H, dtype=dtype, device="cuda") + + x = torch.as_strided(buf_x, (2, H), (_INT64_STRIDE, 1)) + assert not x.is_contiguous() + x_ref, r_ref = fused_add_rms_norm(x.contiguous().clone(), r.clone(), w, eps) + + flashinfer.fused_add_rmsnorm(x, r, w, eps) + + torch.testing.assert_close(x.contiguous(), x_ref, rtol=1e-3, atol=1e-3) + torch.testing.assert_close(r, r_ref, rtol=1e-3, atol=1e-3) + + +def test_fused_add_rmsnorm_quant_int64_stride(): + """fused_add_rmsnorm_quant with row stride > INT32_MAX (issue #3005).""" + _skip_if_low_vram() + H = 128 + dtype = torch.bfloat16 + eps = 1e-6 + quant_scale = 1.0 + buf_x = torch.randn(_INT64_STRIDE + H, dtype=dtype, device="cuda") + w = torch.randn(H, dtype=dtype, device="cuda") + r = torch.randn(2, H, dtype=dtype, device="cuda") + + x = torch.as_strided(buf_x, (2, H), (_INT64_STRIDE, 1)) + assert not x.is_contiguous() + x_ref, r_ref = fused_add_rms_norm_quant( + x.contiguous().clone(), r.clone(), w, quant_scale, eps + ) + + y = torch.empty(2, H, dtype=torch.float8_e4m3fn, device="cuda") + flashinfer.norm.fused_add_rmsnorm_quant( + y, x, r, w, torch.tensor(quant_scale, device="cuda"), eps + ) + + torch.testing.assert_close(y.float(), x_ref.float(), rtol=1, atol=1) + torch.testing.assert_close(r, r_ref, rtol=1e-3, atol=1e-3) + + def test_norm_compilation_without_fp8(): """Test that norm module compiles successfully without ENABLE_FP8 flag.