Closed
8 changes: 8 additions & 0 deletions sgl-kernel/python/sgl_kernel/elementwise.py
@@ -120,6 +120,8 @@ def rmsnorm(
     ):
         return _rmsnorm_internal(input, weight, eps, out, enable_pdl)
     else:
+        if weight.dtype != input.dtype:
+            weight = weight.to(input.dtype)
Comment on lines +123 to +124
Contributor

medium

This dtype casting logic is repeated in fused_add_rmsnorm, gemma_rmsnorm, and gemma_fused_add_rmsnorm. To improve maintainability and reduce code duplication, you could extract this logic into a helper function.

First, define the helper function, for example at the top of the file:

def _maybe_cast_weight(weight: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
    """Casts weight to match input's dtype if they differ."""
    if weight.dtype != input_tensor.dtype:
        return weight.to(input_tensor.dtype)
    return weight

Then, you can use this helper in the four norm functions. The suggestion below shows how to apply it for rmsnorm. The same change can be applied to the other three functions.

        weight = _maybe_cast_weight(weight, input)
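
A quick way to sanity-check the proposed helper is to run it on CPU tensors. The sketch below assumes only that PyTorch is installed. The tensor names (`x`, `w_fp32`, `w_bf16`) and shapes are made up for illustration and do not come from the PR.

```python
import torch


def _maybe_cast_weight(weight: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
    """Return weight in input_tensor's dtype, casting only when they differ."""
    if weight.dtype != input_tensor.dtype:
        return weight.to(input_tensor.dtype)
    return weight


# Mismatched dtypes: float32 weight with bfloat16 activations (illustrative shapes).
x = torch.randn(2, 8, dtype=torch.bfloat16)
w_fp32 = torch.ones(8, dtype=torch.float32)
w_cast = _maybe_cast_weight(w_fp32, x)  # new bfloat16 tensor; w_fp32 is unchanged

# Matching dtypes: the same tensor object comes back, so no copy is made.
w_bf16 = torch.ones(8, dtype=torch.bfloat16)
w_same = _maybe_cast_weight(w_bf16, x)
```

Note that `Tensor.to(dtype)` already returns the tensor itself when the dtype matches, so the explicit comparison in the helper is mainly there for readability. When the dtypes do differ, the cast allocates a new tensor on every call.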

         return _flashinfer_norm.rmsnorm(input, weight, eps, out, enable_pdl)


@@ -162,6 +164,8 @@ def fused_add_rmsnorm(
     ):
         _fused_add_rmsnorm_internal(input, residual, weight, eps, enable_pdl)
     else:
+        if weight.dtype != input.dtype:
+            weight = weight.to(input.dtype)
         _flashinfer_norm.fused_add_rmsnorm(input, residual, weight, eps, enable_pdl)


@@ -205,6 +209,8 @@ def gemma_rmsnorm(
     ):
         return _gemma_rmsnorm_internal(input, weight, eps, out, enable_pdl)
     else:
+        if weight.dtype != input.dtype:
+            weight = weight.to(input.dtype)
         return _flashinfer_norm.gemma_rmsnorm(input, weight, eps, out, enable_pdl)


@@ -247,6 +253,8 @@ def gemma_fused_add_rmsnorm(
     ):
         _gemma_fused_add_rmsnorm_internal(input, residual, weight, eps, enable_pdl)
     else:
+        if weight.dtype != input.dtype:
+            weight = weight.to(input.dtype)
         _flashinfer_norm.gemma_fused_add_rmsnorm(
             input, residual, weight, eps, enable_pdl
         )