[Fix] Cast weight dtype in sgl-kernel norm wrappers for flashinfer 0.6.7#21645
Conversation
…6.7 compatibility

Flashinfer 0.6.7 switched to CuTe-based kernels with stricter dtype validation for rmsnorm. When the weight (fp32) and input (bf16/fp16) dtypes mismatch, the new kernels raise a ValueError. Cast the weight to the input dtype before calling the flashinfer norm functions.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Code Review
This pull request introduces automatic dtype casting for the weight tensor in the rmsnorm, fused_add_rmsnorm, gemma_rmsnorm, and gemma_fused_add_rmsnorm functions to ensure it matches the input tensor's dtype before calling flashinfer kernels. A review comment suggests refactoring this repeated casting logic into a helper function to reduce code duplication and improve maintainability.
```python
if weight.dtype != input.dtype:
    weight = weight.to(input.dtype)
```
This dtype casting logic is repeated in fused_add_rmsnorm, gemma_rmsnorm, and gemma_fused_add_rmsnorm. To improve maintainability and reduce code duplication, you could extract this logic into a helper function.
First, define the helper function, for example at the top of the file:
```python
def _maybe_cast_weight(weight: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
    """Casts weight to match input's dtype if they differ."""
    if weight.dtype != input_tensor.dtype:
        return weight.to(input_tensor.dtype)
    return weight
```

Then, you can use this helper in the four norm functions. The suggestion below shows how to apply it for `rmsnorm`; the same change can be applied to the other three functions.
```python
weight = _maybe_cast_weight(weight, input)
```
Cherry-picked by #21422
Summary
- flashinfer 0.6.7 switched the `rmsnorm` implementation to CuTe-based kernels via TVM FFI, which enforce strict dtype matching between input and weight tensors
- When the `RMSNorm` weight is fp32 but the input is bf16/fp16 (common in diffusion models), the new kernels raise `ValueError: Mismatched Tensor on argument #1`
- Cast weight to the input dtype in `rmsnorm`, `fused_add_rmsnorm`, `gemma_rmsnorm`, and `gemma_fused_add_rmsnorm`
- Fixes the CI failure in PR #21422 (`bench_fused_norm_scale_shift.py`)

Test plan
- `stage-b-kernel-benchmark-1-gpu-large` suite

🤖 Generated with Claude Code