diff --git a/python/sglang/srt/layers/activation.py b/python/sglang/srt/layers/activation.py
index 9e7f933d701..b428246a702 100644
--- a/python/sglang/srt/layers/activation.py
+++ b/python/sglang/srt/layers/activation.py
@@ -28,9 +28,9 @@
     get_tensor_model_parallel_world_size,
 )
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.utils import is_cuda_available, set_weight_attrs
+from sglang.srt.utils import is_cuda, set_weight_attrs
 
-_is_cuda = is_cuda_available()
+_is_cuda = is_cuda()
 
 if _is_cuda:
     from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
diff --git a/python/sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py b/python/sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py
index 459e43b4818..6b6ede927bf 100644
--- a/python/sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py
+++ b/python/sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py
@@ -3,10 +3,10 @@
 import triton.language as tl
 
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.utils import is_hip
+from sglang.srt.utils import is_cuda, is_hip
 
-is_cuda_available = torch.cuda.is_available()
-if is_cuda_available:
+_is_cuda = is_cuda()
+if _is_cuda:
     CUDA_CAPABILITY = torch.cuda.get_device_capability()
 
 _is_hip = is_hip()
@@ -1037,12 +1037,12 @@ def extend_attention_fwd(
         num_warps = 4
 
     else:
-        if is_cuda_available and CUDA_CAPABILITY[0] >= 9:
+        if _is_cuda and CUDA_CAPABILITY[0] >= 9:
             if Lq <= 256:
                 BLOCK_M, BLOCK_N = (128, 64)
             else:
                 BLOCK_M, BLOCK_N = (32, 64)
-        elif is_cuda_available and CUDA_CAPABILITY[0] >= 8:
+        elif _is_cuda and CUDA_CAPABILITY[0] >= 8:
             if Lq <= 128:
                 BLOCK_M, BLOCK_N = (128, 128)
             elif Lq <= 256:
diff --git a/python/sglang/srt/layers/attention/triton_ops/extend_attention.py b/python/sglang/srt/layers/attention/triton_ops/extend_attention.py
index f6c0173da5c..35e5c21c62f 100644
--- a/python/sglang/srt/layers/attention/triton_ops/extend_attention.py
+++ b/python/sglang/srt/layers/attention/triton_ops/extend_attention.py
@@ -23,10 +23,10 @@
 from sglang.srt.layers.attention.triton_ops.prefill_attention import (
     context_attention_fwd,
 )
-from sglang.srt.utils import is_hip
+from sglang.srt.utils import is_cuda, is_hip
 
-is_cuda_available = torch.cuda.is_available()
-if is_cuda_available:
+_is_cuda = is_cuda()
+if _is_cuda:
     CUDA_CAPABILITY = torch.cuda.get_device_capability()
 
 _is_hip = is_hip()
@@ -345,12 +345,12 @@ def extend_attention_fwd(
         num_warps = 4
 
     else:
-        if is_cuda_available and CUDA_CAPABILITY[0] >= 9:
+        if _is_cuda and CUDA_CAPABILITY[0] >= 9:
             if Lq <= 256:
                 BLOCK_M, BLOCK_N = (128, 64)
             else:
                 BLOCK_M, BLOCK_N = (32, 64)
-        elif is_cuda_available and CUDA_CAPABILITY[0] >= 8:
+        elif _is_cuda and CUDA_CAPABILITY[0] >= 8:
             # sm86/sm89 has a much smaller shared memory size (100K) than sm80 (160K)
             if CUDA_CAPABILITY[1] == 9 or CUDA_CAPABILITY[1] == 6:
                 if Lq <= 128:
diff --git a/python/sglang/srt/layers/attention/triton_ops/prefill_attention.py b/python/sglang/srt/layers/attention/triton_ops/prefill_attention.py
index d022b972147..ac0fc72af14 100644
--- a/python/sglang/srt/layers/attention/triton_ops/prefill_attention.py
+++ b/python/sglang/srt/layers/attention/triton_ops/prefill_attention.py
@@ -22,8 +22,12 @@
 import triton
 import triton.language as tl
 
-is_cuda_available = torch.cuda.is_available()
-if is_cuda_available:
+from sglang.srt.utils import is_cuda, is_hip
+
+_is_cuda = is_cuda()
+_is_hip = is_hip()
+
+if _is_cuda or _is_hip:
     CUDA_CAPABILITY = torch.cuda.get_device_capability()
 
 
@@ -172,7 +176,7 @@ def context_attention_fwd(
     b_seq_len: [b]
    out: [b * s, head, head_dim]
     """
-    if is_cuda_available and CUDA_CAPABILITY[0] > 8:
+    if (_is_cuda or _is_hip) and CUDA_CAPABILITY[0] > 8:
         BLOCK = 128
     else:
         BLOCK = 64
diff --git a/python/sglang/srt/layers/layernorm.py b/python/sglang/srt/layers/layernorm.py
index 0359d72349e..3c18cea7046 100644
--- a/python/sglang/srt/layers/layernorm.py
+++ b/python/sglang/srt/layers/layernorm.py
@@ -20,9 +20,9 @@
 import torch.nn as nn
 
 from sglang.srt.custom_op import CustomOp
-from sglang.srt.utils import is_cuda_available
+from sglang.srt.utils import is_cuda
 
-_is_cuda = is_cuda_available()
+_is_cuda = is_cuda()
 
 if _is_cuda:
     from sgl_kernel import (
diff --git a/python/sglang/srt/layers/quantization/modelopt_quant.py b/python/sglang/srt/layers/quantization/modelopt_quant.py
index fc32e53f6da..aebc4524436 100644
--- a/python/sglang/srt/layers/quantization/modelopt_quant.py
+++ b/python/sglang/srt/layers/quantization/modelopt_quant.py
@@ -22,9 +22,9 @@
     requantize_with_max_scale,
 )
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.utils import is_cuda_available
+from sglang.srt.utils import is_cuda
 
-if is_cuda_available():
+if is_cuda():
     from sgl_kernel import cutlass_scaled_fp4_mm, scaled_fp4_quant
 
 # Initialize logger for the module
diff --git a/python/sglang/srt/layers/quantization/w8a8_int8.py b/python/sglang/srt/layers/quantization/w8a8_int8.py
index df345a0a2ef..c263bb3a852 100644
--- a/python/sglang/srt/layers/quantization/w8a8_int8.py
+++ b/python/sglang/srt/layers/quantization/w8a8_int8.py
@@ -11,10 +11,10 @@
     QuantizeMethodBase,
 )
 from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
-from sglang.srt.utils import is_cuda_available, set_weight_attrs
+from sglang.srt.utils import is_cuda, set_weight_attrs
 
-is_cuda = is_cuda_available()
-if is_cuda:
+_is_cuda = is_cuda()
+if _is_cuda:
     from sgl_kernel import int8_scaled_mm
 
 
diff --git a/python/sglang/srt/layers/rotary_embedding.py b/python/sglang/srt/layers/rotary_embedding.py
index 92f5f74e558..95ae14b7246 100644
--- a/python/sglang/srt/layers/rotary_embedding.py
+++ b/python/sglang/srt/layers/rotary_embedding.py
@@ -8,11 +8,11 @@
 import torch.nn as nn
 
 from sglang.srt.custom_op import CustomOp
-from sglang.srt.utils import is_cuda_available
+from sglang.srt.utils import is_cuda
 
-_is_cuda_available = is_cuda_available()
+_is_cuda = is_cuda()
 
-if _is_cuda_available:
+if _is_cuda:
     from sgl_kernel import apply_rope_with_cos_sin_cache_inplace
 else:
     from vllm._custom_ops import rotary_embedding as vllm_rotary_embedding
@@ -82,7 +82,7 @@ def __init__(
         cache = self._compute_cos_sin_cache()
         # NOTE(ByronHsu): cache needs to be in FP32 for numerical stability
-        if not _is_cuda_available:
+        if not _is_cuda:
             cache = cache.to(dtype)
         self.cos_sin_cache: torch.Tensor
         self.register_buffer("cos_sin_cache", cache, persistent=False)
 
@@ -149,7 +149,7 @@ def forward_cuda(
         key: torch.Tensor,
         offsets: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        if _is_cuda_available and (self.head_size in [64, 128, 256, 512]):
+        if _is_cuda and (self.head_size in [64, 128, 256, 512]):
             apply_rope_with_cos_sin_cache_inplace(
                 positions=positions,
                 query=query,
@@ -652,7 +652,7 @@ def forward_hip(self, *args, **kwargs):
     def forward(self, *args, **kwargs):
         if torch.compiler.is_compiling():
             return self.forward_native(*args, **kwargs)
-        if _is_cuda_available:
+        if _is_cuda:
             return self.forward_cuda(*args, **kwargs)
         else:
             return self.forward_native(*args, **kwargs)
diff --git a/python/sglang/srt/layers/sampler.py b/python/sglang/srt/layers/sampler.py
index e0f434a1974..f75fcdf98b7 100644
--- a/python/sglang/srt/layers/sampler.py
+++ b/python/sglang/srt/layers/sampler.py
@@ -10,9 +10,9 @@
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
-from sglang.srt.utils import crash_on_warnings, get_bool_env_var, is_cuda_available
+from sglang.srt.utils import crash_on_warnings, get_bool_env_var, is_cuda
 
-if is_cuda_available():
+if is_cuda():
     from sgl_kernel import (
         min_p_sampling_from_probs,
         top_k_renorm_prob,
diff --git a/python/sglang/srt/models/minicpm3.py b/python/sglang/srt/models/minicpm3.py
index c7053f78f63..1156c3e470d 100644
--- a/python/sglang/srt/models/minicpm3.py
+++ b/python/sglang/srt/models/minicpm3.py
@@ -40,9 +40,9 @@
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.utils import add_prefix, is_cuda_available
+from sglang.srt.utils import add_prefix, is_cuda
 
-if is_cuda_available():
+if is_cuda():
     from sgl_kernel import bmm_fp8
 
 
diff --git a/python/sglang/srt/speculative/build_eagle_tree.py b/python/sglang/srt/speculative/build_eagle_tree.py
index 364ca0677d0..c2840c4a08d 100644
--- a/python/sglang/srt/speculative/build_eagle_tree.py
+++ b/python/sglang/srt/speculative/build_eagle_tree.py
@@ -4,9 +4,9 @@
 import torch
 
-from sglang.srt.utils import is_cuda_available, is_hip
+from sglang.srt.utils import is_cuda, is_hip
 
-if is_cuda_available() or is_hip():
+if is_cuda() or is_hip():
     from sgl_kernel import (
         build_tree_kernel_efficient as sgl_build_tree_kernel_efficient,
     )
 
diff --git a/python/sglang/srt/speculative/eagle_utils.py b/python/sglang/srt/speculative/eagle_utils.py
index 10c9e54c243..b4d339ccea2 100644
--- a/python/sglang/srt/speculative/eagle_utils.py
+++ b/python/sglang/srt/speculative/eagle_utils.py
@@ -19,9 +19,9 @@
 from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
 from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode
 from sglang.srt.speculative.build_eagle_tree import build_tree_kernel_efficient
-from sglang.srt.utils import fast_topk, is_cuda_available, is_hip, next_power_of_2
+from sglang.srt.utils import fast_topk, is_cuda, is_hip, next_power_of_2
 
-if is_cuda_available():
+if is_cuda():
     from sgl_kernel import (
         top_k_renorm_prob,
         top_p_renorm_prob,
diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py
index 06beee8d54b..64de7dbb491 100644
--- a/python/sglang/srt/speculative/eagle_worker.py
+++ b/python/sglang/srt/speculative/eagle_worker.py
@@ -34,14 +34,9 @@
     select_top_k_tokens,
 )
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
-from sglang.srt.utils import (
-    empty_context,
-    fast_topk,
-    get_available_gpu_memory,
-    is_cuda_available,
-)
+from sglang.srt.utils import empty_context, fast_topk, get_available_gpu_memory, is_cuda
 
-if is_cuda_available():
+if is_cuda():
     from sgl_kernel import segment_packbits
 
 logger = logging.getLogger(__name__)
diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py
index a6c2d910bb5..22430a36d83 100644
--- a/python/sglang/srt/utils.py
+++ b/python/sglang/srt/utils.py
@@ -130,10 +130,6 @@ def is_flashinfer_available():
     return importlib.util.find_spec("flashinfer") is not None and is_cuda()
 
 
-def is_cuda_available():
-    return is_cuda()
-
-
 _ENABLE_TORCH_INFERENCE_MODE = get_bool_env_var(
     "SGLANG_ENABLE_TORCH_INFERENCE_MODE", "false"
 )
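
For reference, a minimal sketch of the guard pattern this diff standardizes on. The names `is_cuda`, `is_hip`, `CUDA_CAPABILITY`, and the `sgl_kernel` import are taken from the hunks above; the helper function and its block sizes are illustrative only, not code from any single touched file.

```python
# Sketch only: probe the platform once at import time, cache the result in
# module-level flags, and gate both the sgl_kernel import and any
# device-capability queries on those flags.
import torch

from sglang.srt.utils import is_cuda, is_hip

_is_cuda = is_cuda()
_is_hip = is_hip()

# Query the device only when a GPU backend is present (mirrors the change
# in prefill_attention.py, which now covers ROCm as well).
if _is_cuda or _is_hip:
    CUDA_CAPABILITY = torch.cuda.get_device_capability()

if _is_cuda:
    # sgl_kernel ships CUDA-only fused ops, so its import stays behind the flag.
    from sgl_kernel import silu_and_mul  # noqa: F401


def pick_block_size() -> int:
    # Illustrative helper (not part of the diff): the short-circuit keeps
    # CUDA_CAPABILITY unreferenced on CPU-only builds.
    if (_is_cuda or _is_hip) and CUDA_CAPABILITY[0] > 8:
        return 128
    return 64
```

On CPU-only builds both flags are false, so the fused-kernel import and the capability query are skipped and callers fall through to their non-CUDA paths.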