30 changes: 21 additions & 9 deletions sgl-kernel/python/sgl_kernel/elementwise.py
@@ -1,7 +1,7 @@
from typing import Optional

import torch
from sgl_kernel.utils import get_cuda_stream
from sgl_kernel.utils import get_cuda_stream, is_hopper_arch


# These implementations extensively draw from and build upon the FlashInfer project https://github.com/flashinfer-ai/flashinfer
@@ -11,7 +11,7 @@ def rmsnorm(
weight: torch.Tensor,
eps: float = 1e-6,
out: Optional[torch.Tensor] = None,
enable_pdl: bool = False,
enable_pdl: Optional[bool] = None,
) -> torch.Tensor:
r"""Root mean square normalization.

@@ -27,9 +27,10 @@ def rmsnorm(
Epsilon for numerical stability.
out: Optional[torch.Tensor]
The output tensor, if specified, the kernel will update this tensor inplace.
enable_pdl: bool
enable_pdl: Optional[bool]
Whether to enable `programmatic dependent launch
<https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#programmatic-dependent-launch-and-synchronization>`_
If None, will be automatically enabled on Hopper architecture.

Returns
-------
@@ -38,6 +39,8 @@
"""
if out is None:
out = torch.empty_like(input)
if enable_pdl is None:
enable_pdl = is_hopper_arch()
Member:

We should cache this result instead of calling it every time.
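A minimal sketch of that caching suggestion (not the change applied in this PR) would memoize the helper with `functools.lru_cache`, so the CUDA device query runs only once per process:

```python
import functools

import torch


@functools.lru_cache(maxsize=1)
def is_hopper_arch() -> bool:
    # Query the current device's compute capability once and reuse the result;
    # Hopper reports compute capability 9.x.
    device = torch.cuda.current_device()
    major, _minor = torch.cuda.get_device_capability(device)
    return major == 9
```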

Contributor Author:

Thanks, I will do it later.

Collaborator:

More precisely, Hopper or later architectures.

Contributor Author:

Thanks!
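As a sketch of that refinement (the helper name here is hypothetical, not part of this diff), the check could compare the major compute capability with `>=` instead of `==`:

```python
import torch


def is_hopper_or_later_arch() -> bool:
    # Hopper is SM 9.0; newer architectures report a higher major version,
    # so a >= comparison covers "Hopper or later".
    major, _minor = torch.cuda.get_device_capability(torch.cuda.current_device())
    return major >= 9
```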

torch.ops.sgl_kernel.rmsnorm.default(out, input, weight, eps, enable_pdl)
return out

@@ -47,7 +50,7 @@ def fused_add_rmsnorm(
residual: torch.Tensor,
weight: torch.Tensor,
eps: float = 1e-6,
enable_pdl: bool = False,
enable_pdl: Optional[bool] = None,
) -> None:
r"""Fused add root mean square normalization.

@@ -67,10 +70,13 @@
Weight tensor, shape (hidden_size,).
eps: float
Epsilon for numerical stability.
enable_pdl: bool
enable_pdl: Optional[bool]
Whether to enable `programmatic dependent launch
<https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#programmatic-dependent-launch-and-synchronization>`_
If None, will be automatically enabled on Hopper architecture.
"""
if enable_pdl is None:
enable_pdl = is_hopper_arch()
torch.ops.sgl_kernel.fused_add_rmsnorm.default(
input, residual, weight, eps, enable_pdl
)
@@ -81,7 +87,7 @@ def gemma_rmsnorm(
weight: torch.Tensor,
eps: float = 1e-6,
out: Optional[torch.Tensor] = None,
enable_pdl: bool = False,
enable_pdl: Optional[bool] = None,
) -> torch.Tensor:
r"""Gemma-style root mean square normalization.

@@ -97,9 +103,10 @@
Epsilon for numerical stability.
out: Optional[torch.Tensor]
The output tensor, if specified, the kernel will update this tensor inplace.
enable_pdl: bool
enable_pdl: Optional[bool]
Whether to enable `programmatic dependent launch
<https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#programmatic-dependent-launch-and-synchronization>`_
If None, will be automatically enabled on Hopper architecture.

Returns
-------
@@ -108,6 +115,8 @@
"""
if out is None:
out = torch.empty_like(input)
if enable_pdl is None:
enable_pdl = is_hopper_arch()
torch.ops.sgl_kernel.gemma_rmsnorm.default(out, input, weight, eps, enable_pdl)
return out

@@ -117,7 +126,7 @@ def gemma_fused_add_rmsnorm(
residual: torch.Tensor,
weight: torch.Tensor,
eps: float = 1e-6,
enable_pdl: bool = False,
enable_pdl: Optional[bool] = None,
) -> None:
r"""Gemma-style fused add root mean square normalization.

@@ -137,10 +146,13 @@
Weight tensor, shape (hidden_size,).
eps: float
Epsilon for numerical stability.
enable_pdl: bool
enable_pdl: Optional[bool]
Whether to enable `programmatic dependent launch
<https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#programmatic-dependent-launch-and-synchronization>`_
If None, will be automatically enabled on Hopper architecture.
"""
if enable_pdl is None:
enable_pdl = is_hopper_arch()
torch.ops.sgl_kernel.gemma_fused_add_rmsnorm.default(
input, residual, weight, eps, enable_pdl
)
7 changes: 7 additions & 0 deletions sgl-kernel/python/sgl_kernel/utils.py
@@ -39,3 +39,10 @@ def _to_tensor_scalar_tuple(x):
return (x, 0)
else:
return (None, x)


def is_hopper_arch() -> bool:
# Hopper arch's compute capability == 9.0
device = torch.cuda.current_device()
major, minor = torch.cuda.get_device_capability(device)
return major == 9
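A usage sketch under the new defaults (assuming `rmsnorm` is re-exported at the `sgl_kernel` package top level): callers can omit `enable_pdl`, and PDL is enabled automatically when the current device is Hopper.

```python
import torch
from sgl_kernel import rmsnorm  # assumption: rmsnorm is exported at the package top level

x = torch.randn(8, 4096, device="cuda", dtype=torch.float16)
w = torch.ones(4096, device="cuda", dtype=torch.float16)

# enable_pdl defaults to None, so it is resolved via is_hopper_arch() at call time.
y = rmsnorm(x, w, eps=1e-6)
```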