From 63a57e82e4f5c053d6fa64c5c806830dada3bc8a Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Sun, 29 Mar 2026 16:43:17 -0700 Subject: [PATCH] [Fix] Cast weight dtype in sgl-kernel norm wrappers for flashinfer 0.6.7 compatibility Flashinfer 0.6.7 switched to CuTe-based kernels with stricter dtype validation for rmsnorm. When weight (fp32) and input (bf16/fp16) dtypes mismatch, the new kernels raise ValueError. Cast weight to input dtype before calling flashinfer norm functions. Co-Authored-By: Claude Opus 4.6 (1M context) --- sgl-kernel/python/sgl_kernel/elementwise.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/sgl-kernel/python/sgl_kernel/elementwise.py b/sgl-kernel/python/sgl_kernel/elementwise.py index 1ed1ae474a79..547948fcda49 100644 --- a/sgl-kernel/python/sgl_kernel/elementwise.py +++ b/sgl-kernel/python/sgl_kernel/elementwise.py @@ -120,6 +120,8 @@ def rmsnorm( ): return _rmsnorm_internal(input, weight, eps, out, enable_pdl) else: + if weight.dtype != input.dtype: + weight = weight.to(input.dtype) return _flashinfer_norm.rmsnorm(input, weight, eps, out, enable_pdl) @@ -162,6 +164,8 @@ def fused_add_rmsnorm( ): _fused_add_rmsnorm_internal(input, residual, weight, eps, enable_pdl) else: + if weight.dtype != input.dtype: + weight = weight.to(input.dtype) _flashinfer_norm.fused_add_rmsnorm(input, residual, weight, eps, enable_pdl) @@ -205,6 +209,8 @@ def gemma_rmsnorm( ): return _gemma_rmsnorm_internal(input, weight, eps, out, enable_pdl) else: + if weight.dtype != input.dtype: + weight = weight.to(input.dtype) return _flashinfer_norm.gemma_rmsnorm(input, weight, eps, out, enable_pdl) @@ -247,6 +253,8 @@ def gemma_fused_add_rmsnorm( ): _gemma_fused_add_rmsnorm_internal(input, residual, weight, eps, enable_pdl) else: + if weight.dtype != input.dtype: + weight = weight.to(input.dtype) _flashinfer_norm.gemma_fused_add_rmsnorm( input, residual, weight, eps, enable_pdl )