diff --git a/sgl-kernel/python/sgl_kernel/elementwise.py b/sgl-kernel/python/sgl_kernel/elementwise.py
index 0e15cfff317..0e2bbc9904d 100644
--- a/sgl-kernel/python/sgl_kernel/elementwise.py
+++ b/sgl-kernel/python/sgl_kernel/elementwise.py
@@ -1,7 +1,7 @@
 from typing import Optional
 
 import torch
-from sgl_kernel.utils import get_cuda_stream
+from sgl_kernel.utils import get_cuda_stream, is_hopper_arch
 
 # These implementations extensively draw from and build upon the FlashInfer project https://github.com/flashinfer-ai/flashinfer
 
@@ -11,7 +11,7 @@ def rmsnorm(
     weight: torch.Tensor,
     eps: float = 1e-6,
     out: Optional[torch.Tensor] = None,
-    enable_pdl: bool = False,
+    enable_pdl: Optional[bool] = None,
 ) -> torch.Tensor:
     r"""Root mean square normalization.
 
@@ -27,9 +27,10 @@
         Epsilon for numerical stability.
     out: Optional[torch.Tensor]
         The output tensor, if specified, the kernel will update this tensor inplace.
-    enable_pdl: bool
+    enable_pdl: Optional[bool]
         Whether to enable `programmatic dependent launch `_
+        If None, will be automatically enabled on Hopper architecture.
 
     Returns
     -------
@@ -38,6 +39,8 @@
     """
     if out is None:
         out = torch.empty_like(input)
+    if enable_pdl is None:
+        enable_pdl = is_hopper_arch()
     torch.ops.sgl_kernel.rmsnorm.default(out, input, weight, eps, enable_pdl)
     return out
 
@@ -47,7 +50,7 @@ def fused_add_rmsnorm(
     residual: torch.Tensor,
     weight: torch.Tensor,
     eps: float = 1e-6,
-    enable_pdl: bool = False,
+    enable_pdl: Optional[bool] = None,
 ) -> None:
     r"""Fused add root mean square normalization.
 
@@ -67,10 +70,13 @@
         Weight tensor, shape (hidden_size,).
     eps: float
         Epsilon for numerical stability.
-    enable_pdl: bool
+    enable_pdl: Optional[bool]
         Whether to enable `programmatic dependent launch `_
+        If None, will be automatically enabled on Hopper architecture.
     """
+    if enable_pdl is None:
+        enable_pdl = is_hopper_arch()
     torch.ops.sgl_kernel.fused_add_rmsnorm.default(
         input, residual, weight, eps, enable_pdl
     )
 
@@ -81,7 +87,7 @@ def gemma_rmsnorm(
     weight: torch.Tensor,
     eps: float = 1e-6,
     out: Optional[torch.Tensor] = None,
-    enable_pdl: bool = False,
+    enable_pdl: Optional[bool] = None,
 ) -> torch.Tensor:
     r"""Gemma-style root mean square normalization.
 
@@ -97,9 +103,10 @@
         Epsilon for numerical stability.
     out: Optional[torch.Tensor]
         The output tensor, if specified, the kernel will update this tensor inplace.
-    enable_pdl: bool
+    enable_pdl: Optional[bool]
         Whether to enable `programmatic dependent launch `_
+        If None, will be automatically enabled on Hopper architecture.
 
     Returns
     -------
@@ -108,6 +115,8 @@
     """
     if out is None:
         out = torch.empty_like(input)
+    if enable_pdl is None:
+        enable_pdl = is_hopper_arch()
     torch.ops.sgl_kernel.gemma_rmsnorm.default(out, input, weight, eps, enable_pdl)
     return out
 
@@ -117,7 +126,7 @@ def gemma_fused_add_rmsnorm(
     residual: torch.Tensor,
     weight: torch.Tensor,
     eps: float = 1e-6,
-    enable_pdl: bool = False,
+    enable_pdl: Optional[bool] = None,
 ) -> None:
     r"""Gemma-style fused add root mean square normalization.
 
@@ -137,10 +146,13 @@
         Weight tensor, shape (hidden_size,).
     eps: float
         Epsilon for numerical stability.
-    enable_pdl: bool
+    enable_pdl: Optional[bool]
         Whether to enable `programmatic dependent launch `_
+        If None, will be automatically enabled on Hopper architecture.
     """
+    if enable_pdl is None:
+        enable_pdl = is_hopper_arch()
     torch.ops.sgl_kernel.gemma_fused_add_rmsnorm.default(
         input, residual, weight, eps, enable_pdl
     )
diff --git a/sgl-kernel/python/sgl_kernel/utils.py b/sgl-kernel/python/sgl_kernel/utils.py
index d930678f20b..63f14624190 100644
--- a/sgl-kernel/python/sgl_kernel/utils.py
+++ b/sgl-kernel/python/sgl_kernel/utils.py
@@ -39,3 +39,10 @@ def _to_tensor_scalar_tuple(x):
         return (x, 0)
     else:
         return (None, x)
+
+
+def is_hopper_arch() -> bool:
+    # Hopper arch's compute capability == 9.0
+    device = torch.cuda.current_device()
+    major, minor = torch.cuda.get_device_capability(device)
+    return major == 9