Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions flashinfer/norm/kernels/fused_add_rmsnorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -902,8 +902,8 @@ def _get_compiled_fused_add_rmsnorm_kernel(
dtype, (sym_m, H), stride_order=(1, 0), assumed_align=tensor_align
)
else:
sym_row_stride_x = cute.sym_int(divisibility=kernel_obj.vec_size)
sym_row_stride_r = cute.sym_int(divisibility=kernel_obj.vec_size)
sym_row_stride_x = cute.sym_int64(divisibility=kernel_obj.vec_size)
sym_row_stride_r = cute.sym_int64(divisibility=kernel_obj.vec_size)
x_fake = cute.runtime.make_fake_tensor(
dtype, (sym_m, H), (sym_row_stride_x, 1), assumed_align=16
)
Expand Down Expand Up @@ -966,9 +966,9 @@ def _get_compiled_fused_add_rmsnorm_quant_kernel(
dtype, (sym_m, H), stride_order=(1, 0), assumed_align=in_align
)
else:
sym_row_stride_y = cute.sym_int(divisibility=kernel_obj.vec_size)
sym_row_stride_x = cute.sym_int(divisibility=kernel_obj.vec_size)
sym_row_stride_r = cute.sym_int(divisibility=kernel_obj.vec_size)
sym_row_stride_y = cute.sym_int64(divisibility=kernel_obj.vec_size)
sym_row_stride_x = cute.sym_int64(divisibility=kernel_obj.vec_size)
sym_row_stride_r = cute.sym_int64(divisibility=kernel_obj.vec_size)
y_fake = cute.runtime.make_fake_tensor(
out_dtype, (sym_m, H), (sym_row_stride_y, 1), assumed_align=16
)
Expand Down
16 changes: 8 additions & 8 deletions flashinfer/norm/kernels/rmsnorm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1127,8 +1127,8 @@ def _get_compiled_rmsnorm_kernel(
dtype, (sym_m, H), stride_order=(1, 0), assumed_align=tensor_align
)
else:
sym_row_stride_x = cute.sym_int(divisibility=kernel_obj.vec_size)
sym_row_stride_y = cute.sym_int(divisibility=kernel_obj.vec_size)
sym_row_stride_x = cute.sym_int64(divisibility=kernel_obj.vec_size)
sym_row_stride_y = cute.sym_int64(divisibility=kernel_obj.vec_size)
x_fake = cute.runtime.make_fake_tensor(
dtype, (sym_m, H), (sym_row_stride_x, 1), assumed_align=16
)
Expand Down Expand Up @@ -1168,10 +1168,10 @@ def _get_compiled_qk_rmsnorm_kernel(

# Stride divisibility = vec_size guarantees each row start is aligned
# for the chosen copy_bits (e.g. vec_size=8 for fp16 β†’ 16-byte aligned).
sym_batch_stride_x = cute.sym_int(divisibility=kernel_obj.vec_size)
sym_head_stride_x = cute.sym_int(divisibility=kernel_obj.vec_size)
sym_batch_stride_y = cute.sym_int(divisibility=kernel_obj.vec_size)
sym_head_stride_y = cute.sym_int(divisibility=kernel_obj.vec_size)
sym_batch_stride_x = cute.sym_int64(divisibility=kernel_obj.vec_size)
sym_head_stride_x = cute.sym_int64(divisibility=kernel_obj.vec_size)
sym_batch_stride_y = cute.sym_int64(divisibility=kernel_obj.vec_size)
sym_head_stride_y = cute.sym_int64(divisibility=kernel_obj.vec_size)

x_fake = cute.runtime.make_fake_tensor(
dtype,
Expand Down Expand Up @@ -1238,8 +1238,8 @@ def _get_compiled_rmsnorm_quant_kernel(
out_dtype, (sym_m, H), stride_order=(1, 0), assumed_align=out_align
)
else:
sym_row_stride_x = cute.sym_int(divisibility=kernel_obj.vec_size)
sym_row_stride_y = cute.sym_int(divisibility=kernel_obj.vec_size)
sym_row_stride_x = cute.sym_int64(divisibility=kernel_obj.vec_size)
sym_row_stride_y = cute.sym_int64(divisibility=kernel_obj.vec_size)
x_fake = cute.runtime.make_fake_tensor(
dtype, (sym_m, H), (sym_row_stride_x, 1), assumed_align=16
)
Expand Down
124 changes: 124 additions & 0 deletions tests/utils/test_norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,130 @@ def test_layernorm(batch_size, hidden_size, dtype):
torch.testing.assert_close(out, out_ref, rtol=1e-2, atol=1e-2)


# =============================================================================
# Regression tests for int32 stride overflow
# =============================================================================
# These tests verify that rmsnorm kernels accept tensors with strides exceeding
# INT32_MAX. The 2D tests use M=2 so that is_contiguous() returns False and the
# non-contiguous kernel path (which uses sym_int64 strides) is exercised. This
# requires a ~4 GB flat buffer so the large stride is actually traversable.
# The 3D qknorm test can use batch=1 because qk_rmsnorm_cute always uses
# symbolic strides regardless of contiguity.

# One element past INT32_MAX, so any stride stored in an int32 would overflow.
_INT64_STRIDE = 2**31  # just above INT32_MAX = 2**31 - 1
# Flat-buffer size (bytes) needed so a row at offset _INT64_STRIDE is addressable:
# one full stride plus one H=128 row, times 2 bytes per bf16 element.
_STRIDE_BUF_BYTES = (_INT64_STRIDE + 128) * 2  # bf16, H=128


def _skip_if_low_vram():
    """Skip the calling test when the GPU lacks room for the ~4 GB stride buffer.

    Requires a 20% headroom over ``_STRIDE_BUF_BYTES`` of free device memory.
    """
    free = torch.cuda.mem_get_info()[0]
    needed = _STRIDE_BUF_BYTES * 1.2
    if free < needed:
        pytest.skip(
            f"Requires ~{_STRIDE_BUF_BYTES / 1024**3:.1f}GB free VRAM, "
            f"only {free / 1024**3:.1f}GB available"
        )


def test_rmsnorm_int64_stride():
    """2D rmsnorm with row stride > INT32_MAX (issue #3005)."""
    _skip_if_low_vram()
    hidden = 128
    dt = torch.bfloat16
    flat = torch.randn(_INT64_STRIDE + hidden, dtype=dt, device="cuda")
    weight = torch.randn(hidden, dtype=dt, device="cuda")

    # Two rows spaced _INT64_STRIDE elements apart: non-contiguous, so the
    # symbolic-stride (int64) kernel path is taken.
    x = torch.as_strided(flat, (2, hidden), (_INT64_STRIDE, 1))
    assert not x.is_contiguous()
    out = flashinfer.norm.rmsnorm(x, weight)

    expected = llama_rms_norm(x.contiguous(), weight)
    torch.testing.assert_close(out, expected, rtol=1e-3, atol=1e-3)
Comment thread
coderabbitai[bot] marked this conversation as resolved.


def test_qknorm_int64_stride():
    """3D qk_rmsnorm with batch stride > INT32_MAX (issue #3005).

    qk_rmsnorm_cute always uses symbolic strides (no contiguity check),
    so batch=1 suffices — the large stride is validated by TVM-FFI even
    though it is never traversed.

    NOTE(review): dispatches through ``flashinfer.norm.rmsnorm`` on a 3D
    input; presumably that routes to the qk path — confirm against the op.
    """
    heads, dim = 4, 128
    dt = torch.bfloat16
    storage = torch.randn(1, heads, dim, dtype=dt, device="cuda")
    weight = torch.randn(dim, dtype=dt, device="cuda")

    # batch=1 means batch index 0 is the only one read, so the oversized
    # stride exercises stride bookkeeping without a giant allocation.
    view = torch.as_strided(storage, (1, heads, dim), (_INT64_STRIDE, dim, 1))
    out = flashinfer.norm.rmsnorm(view, weight)

    expected = llama_rms_norm(storage, weight)
    torch.testing.assert_close(out, expected, rtol=1e-3, atol=1e-3)


def test_rmsnorm_quant_int64_stride():
    """rmsnorm_quant with row stride > INT32_MAX (issue #3005)."""
    _skip_if_low_vram()
    hidden = 128
    dt = torch.bfloat16
    scale = 1.0
    flat = torch.randn(_INT64_STRIDE + hidden, dtype=dt, device="cuda")
    weight = torch.randn(hidden, dtype=dt, device="cuda")

    # Oversized row stride forces the non-contiguous (int64-stride) path.
    x = torch.as_strided(flat, (2, hidden), (_INT64_STRIDE, 1))
    assert not x.is_contiguous()
    out = torch.empty(2, hidden, dtype=torch.float8_e4m3fn, device="cuda")
    flashinfer.norm.rmsnorm_quant(out, x, weight, torch.tensor(scale, device="cuda"))

    expected = llama_rms_norm_quant(x.contiguous(), weight, scale)
    # fp8 output: loose tolerance absorbs quantization rounding.
    torch.testing.assert_close(out.float(), expected.float(), rtol=1, atol=1)


def test_fused_add_rmsnorm_int64_stride():
    """fused_add_rmsnorm with row stride > INT32_MAX (issue #3005)."""
    _skip_if_low_vram()
    hidden = 128
    dt = torch.bfloat16
    eps = 1e-6
    flat_x = torch.randn(_INT64_STRIDE + hidden, dtype=dt, device="cuda")
    weight = torch.randn(hidden, dtype=dt, device="cuda")
    # Contiguous residual — only one non-contiguous tensor is needed to
    # trigger the non-contiguous kernel path.
    residual = torch.randn(2, hidden, dtype=dt, device="cuda")

    x = torch.as_strided(flat_x, (2, hidden), (_INT64_STRIDE, 1))
    assert not x.is_contiguous()
    # Build references before the in-place op mutates x and residual.
    x_ref, r_ref = fused_add_rms_norm(x.contiguous().clone(), residual.clone(), weight, eps)

    flashinfer.fused_add_rmsnorm(x, residual, weight, eps)

    torch.testing.assert_close(x.contiguous(), x_ref, rtol=1e-3, atol=1e-3)
    torch.testing.assert_close(residual, r_ref, rtol=1e-3, atol=1e-3)


def test_fused_add_rmsnorm_quant_int64_stride():
    """fused_add_rmsnorm_quant with row stride > INT32_MAX (issue #3005)."""
    _skip_if_low_vram()
    hidden = 128
    dt = torch.bfloat16
    eps = 1e-6
    scale = 1.0
    flat_x = torch.randn(_INT64_STRIDE + hidden, dtype=dt, device="cuda")
    weight = torch.randn(hidden, dtype=dt, device="cuda")
    residual = torch.randn(2, hidden, dtype=dt, device="cuda")

    x = torch.as_strided(flat_x, (2, hidden), (_INT64_STRIDE, 1))
    assert not x.is_contiguous()
    # Compute references before the kernel mutates residual in place.
    x_ref, r_ref = fused_add_rms_norm_quant(
        x.contiguous().clone(), residual.clone(), weight, scale, eps
    )

    out = torch.empty(2, hidden, dtype=torch.float8_e4m3fn, device="cuda")
    flashinfer.norm.fused_add_rmsnorm_quant(
        out, x, residual, weight, torch.tensor(scale, device="cuda"), eps
    )

    # fp8 output: loose tolerance absorbs quantization rounding.
    torch.testing.assert_close(out.float(), x_ref.float(), rtol=1, atol=1)
    torch.testing.assert_close(residual, r_ref, rtol=1e-3, atol=1e-3)


def test_norm_compilation_without_fp8():
"""Test that norm module compiles successfully without ENABLE_FP8 flag.

Expand Down
Loading