
[Fix] Cast weight dtype in sgl-kernel norm wrappers for flashinfer 0.6.7#21645

Closed
Fridge003 wants to merge 1 commit into main from fix/flashinfer-rmsnorm-dtype-cast

Conversation

@Fridge003
Collaborator

Summary

  • Flashinfer 0.6.7 switched its rmsnorm implementation to CuTe-based kernels via TVM FFI, which enforce strict dtype matching between input and weight tensors
  • When RMSNorm weight is fp32 but input is bf16/fp16 (common in diffusion models), the new kernels raise ValueError: Mismatched Tensor on argument #1
  • This fix casts weight to match input dtype before calling flashinfer norm functions in all 4 affected wrappers: rmsnorm, fused_add_rmsnorm, gemma_rmsnorm, gemma_fused_add_rmsnorm

Fixes the CI failure in PR #21422 (bench_fused_norm_scale_shift.py).
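
The fix described above amounts to a defensive cast at the top of each wrapper. A minimal sketch of the pattern, with a plain fp32 reference computation standing in for the actual flashinfer kernel call (the real wrappers forward to flashinfer's norm functions):

```python
import torch

def rmsnorm(input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Sketch of the wrapper fix: cast weight to the input dtype before
    handing both tensors to the strict-dtype kernel."""
    if weight.dtype != input.dtype:
        weight = weight.to(input.dtype)
    # Stand-in for the flashinfer rmsnorm kernel: compute in fp32 for
    # stability, then cast back to the input dtype.
    x = input.float()
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    out = x * torch.rsqrt(variance + eps) * weight.float()
    return out.to(input.dtype)

# fp32 weight with bf16 input no longer trips the strict dtype check.
x = torch.randn(2, 8, dtype=torch.bfloat16)
w = torch.ones(8, dtype=torch.float32)
y = rmsnorm(x, w)
```

The same two-line cast is applied in `fused_add_rmsnorm`, `gemma_rmsnorm`, and `gemma_fused_add_rmsnorm`.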

Test plan

  • Reproduced the bug on H200 with flashinfer 0.6.7
  • Verified fix resolves the dtype mismatch error
  • Verified numerical correctness (exact match when using same-dtype reference)
  • CI should pass stage-b-kernel-benchmark-1-gpu-large suite

🤖 Generated with Claude Code

…6.7 compatibility

Flashinfer 0.6.7 switched to CuTe-based kernels with stricter dtype
validation for rmsnorm. When weight (fp32) and input (bf16/fp16) dtypes
mismatch, the new kernels raise ValueError. Cast weight to input dtype
before calling flashinfer norm functions.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces automatic dtype casting for the weight tensor in the rmsnorm, fused_add_rmsnorm, gemma_rmsnorm, and gemma_fused_add_rmsnorm functions to ensure it matches the input tensor's dtype before calling flashinfer kernels. A review comment suggests refactoring this repeated casting logic into a helper function to reduce code duplication and improve maintainability.

Comment on lines +123 to +124
if weight.dtype != input.dtype:
weight = weight.to(input.dtype)

Severity: medium

This dtype casting logic is repeated in fused_add_rmsnorm, gemma_rmsnorm, and gemma_fused_add_rmsnorm. To improve maintainability and reduce code duplication, you could extract this logic into a helper function.

First, define the helper function, for example at the top of the file:

def _maybe_cast_weight(weight: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
    """Casts weight to match input's dtype if they differ."""
    if weight.dtype != input_tensor.dtype:
        return weight.to(input_tensor.dtype)
    return weight

Then, you can use this helper in the four norm functions. The suggestion below shows how to apply it for rmsnorm. The same change can be applied to the other three functions.

        weight = _maybe_cast_weight(weight, input)

@Fridge003
Collaborator Author

Cherry-picked by #21422

@Fridge003 Fridge003 closed this Mar 30, 2026
