4 changes: 4 additions & 0 deletions docs/backend/native_api.ipynb
@@ -51,6 +51,10 @@
"server_process, port = launch_server_cmd(\n",
" \"python3 -m sglang.launch_server --model-path qwen/qwen2.5-0.5b-instruct --host 0.0.0.0\"\n",
")\n",
"## To run qwen2.5-0.5b-instruct model on the Ascend-Npu, you can execute the following command:\n",
"# server_process, port = launch_server_cmd(\n",
"# \"python3 -m sglang.launch_server --model-path qwen/qwen2.5-0.5b-instruct --host 0.0.0.0 --device npu --tp 2 --attention-backend torch_native\"\n",
"# )\n",
"\n",
"wait_for_server(f\"http://localhost:{port}\")"
]
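If the server is launched on an Ascend NPU with the commented-out command above, it can be smoke-tested the same way as the CUDA launch. A minimal sketch, not part of the notebook, assuming the server exposes the standard native /generate endpoint on the returned port:

import requests

# Hypothetical check: send one request to the native /generate endpoint of the
# server started above; `port` comes from launch_server_cmd().
response = requests.post(
    f"http://localhost:{port}/generate",
    json={
        "text": "The capital of France is",
        "sampling_params": {"temperature": 0, "max_new_tokens": 16},
    },
)
print(response.json())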
4 changes: 2 additions & 2 deletions python/sglang/srt/_custom_ops.py
@@ -4,7 +4,7 @@

import torch

from sglang.srt.utils import get_bool_env_var, is_hip, is_hpu
from sglang.srt.utils import get_bool_env_var, is_hip, is_hpu, is_npu

logger = logging.getLogger(__name__)
use_vllm_custom_allreduce = get_bool_env_var(
@@ -25,7 +25,7 @@
logger.warning("Failed to import from custom_ar with %r", e)


if not is_hip():
if not is_hip() and not is_npu():
if use_vllm_custom_allreduce:
custom_op = torch.ops._C_custom_ar
else:
5 changes: 3 additions & 2 deletions python/sglang/srt/layers/activation.py
@@ -29,10 +29,11 @@
get_tensor_model_parallel_world_size,
)
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.utils import is_cuda, set_weight_attrs
from sglang.srt.utils import is_cuda, is_npu, set_weight_attrs
from sglang.utils import resolve_obj_by_qualname

_is_cuda = is_cuda()
_is_npu = is_npu()

if _is_cuda:
from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
@@ -184,7 +185,7 @@ def get_cross_encoder_activation_function(config: PretrainedConfig):
return nn.Identity()


if not _is_cuda:
if not _is_cuda and not _is_npu:
logger.info(
"sgl-kernel is not available on Non-NV platforms. Fallback to other kernel libraries."
)
5 changes: 3 additions & 2 deletions python/sglang/srt/layers/layernorm.py
@@ -20,10 +20,11 @@
import torch.nn as nn

from sglang.srt.custom_op import CustomOp
from sglang.srt.utils import get_bool_env_var, is_cuda, is_hip
from sglang.srt.utils import get_bool_env_var, is_cuda, is_hip, is_npu

_is_cuda = is_cuda()
_is_hip = is_hip()
_is_npu = is_npu()
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip

if _is_cuda:
@@ -187,7 +188,7 @@ def extra_repr(self):
return f"{tuple(self.weight.shape)}, eps={self.eps}"


if not (_is_cuda or _is_hip):
if not (_is_cuda or _is_hip or _is_npu):
logger.info(
"sgl-kernel layernorm implementation is not available on current platform. Fallback to other kernel libraries."
)
@@ -17,11 +17,12 @@
per_tensor_dequantize,
replace_parameter,
)
from sglang.srt.utils import is_cuda, set_weight_attrs
from sglang.srt.utils import is_cuda, is_npu, set_weight_attrs

_is_cuda = is_cuda()
_is_npu = is_npu()

if not _is_cuda:
if not _is_cuda and not _is_npu:
from vllm import _custom_ops as vllm_ops
from vllm._custom_ops import scaled_fp8_quant

4 changes: 3 additions & 1 deletion python/sglang/srt/layers/quantization/fp8.py
@@ -67,13 +67,15 @@ def dummy_func(*args, **kwargs):
get_bool_env_var,
is_cuda,
is_hip,
is_npu,
log_info_on_rank0,
print_warning_once,
set_weight_attrs,
)

_is_hip = is_hip()
_is_cuda = is_cuda()
_is_npu = is_npu()

_is_fp8_fnuz = is_fp8_fnuz()

@@ -86,7 +88,7 @@ def dummy_func(*args, **kwargs):
from aiter.fused_moe_bf16_asm import asm_moe, ck_moe_2stages
from aiter.ops.shuffle import shuffle_weight

if not _is_cuda:
if not _is_cuda and not _is_npu:
from vllm._custom_ops import scaled_fp8_quant


5 changes: 3 additions & 2 deletions python/sglang/srt/layers/quantization/utils.py
@@ -6,11 +6,12 @@
import torch

from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
from sglang.srt.utils import is_cuda
from sglang.srt.utils import is_cuda, is_npu

_is_cuda = is_cuda()
_is_npu = is_npu()

if not _is_cuda:
if not _is_cuda and not _is_npu:
from vllm._custom_ops import scaled_fp8_quant


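The same `not _is_cuda and not _is_npu` guard recurs across the files above; the apparent intent is that CUDA uses sgl-kernel's fused ops, Ascend relies on torch_npu-backed paths, and only the remaining platforms import vLLM's compiled fallback. A condensed sketch of the pattern, written for this review rather than taken from the PR:

from sglang.srt.utils import is_cuda, is_npu

_is_cuda = is_cuda()
_is_npu = is_npu()

if not _is_cuda and not _is_npu:
    # Neither sgl-kernel (CUDA-only) nor a torch_npu-backed path applies here,
    # so fall back to vLLM's compiled kernel.
    from vllm._custom_ops import scaled_fp8_quant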
5 changes: 3 additions & 2 deletions python/sglang/srt/layers/rotary_embedding.py
@@ -8,10 +8,11 @@
import torch.nn as nn

from sglang.srt.custom_op import CustomOp
from sglang.srt.utils import is_cuda, is_hip
from sglang.srt.utils import is_cuda, is_hip, is_npu

_is_cuda = is_cuda()
_is_hip = is_hip()
_is_npu = is_npu()

if _is_cuda:
from sgl_kernel import apply_rope_with_cos_sin_cache_inplace
@@ -84,7 +85,7 @@ def __init__(
if not _is_cuda:
cache = cache.to(dtype)

if not _is_cuda or self.head_size not in [64, 128, 256, 512]:
if not (_is_cuda or _is_npu) or self.head_size not in [64, 128, 256, 512]:
from vllm._custom_ops import rotary_embedding

self.vllm_rotary_embedding = rotary_embedding
42 changes: 39 additions & 3 deletions python/sglang/srt/utils.py
@@ -1291,13 +1291,24 @@ def get_hpu_memory_capacity():
)


def get_npu_memory_capacity():
try:
import torch_npu  # noqa: F401  (imported for the side effect of registering torch.npu)

return torch.npu.mem_get_info()[1] // 1024 // 1024  # unit: MB
except ImportError as e:
raise ImportError("torch_npu is required when running on an NPU device.") from e


def get_device_memory_capacity(device: str = None):
if is_cuda():
gpu_mem = get_nvgpu_memory_capacity()
elif is_hip():
gpu_mem = get_amdgpu_memory_capacity()
elif device == "hpu":
gpu_mem = get_hpu_memory_capacity()
elif device == "npu":
gpu_mem = get_npu_memory_capacity()
else:
# GPU memory is not known yet or no GPU is available.
gpu_mem = None
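For reference, the new branch keeps capacity queries uniform across devices. A hypothetical call, assuming torch_npu is installed, an Ascend device is visible, and the function returns gpu_mem as it does for the other branches:

from sglang.srt.utils import get_device_memory_capacity

# Total device memory in MB, obtained via torch.npu.mem_get_info() on Ascend.
npu_mem_mb = get_device_memory_capacity(device="npu")
print(f"NPU memory capacity: {npu_mem_mb} MB")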
@@ -1423,6 +1434,11 @@ def get_device(device_id: Optional[int] = None) -> str:
return "xpu"
return "xpu:{}".format(device_id)

if hasattr(torch, "npu") and torch.npu.is_available():
if device_id is None:
return "npu"
return "npu:{}".format(device_id)

if is_habana_available():
try:
import habana_frameworks.torch.hpu
@@ -1497,15 +1513,35 @@ def get_device_capability(device_id: int = 0) -> Tuple[int, int]:
return major, minor


def get_npu_compiler_config():
config = {
"frozen_parameter": True,
"tiling_schedule_optimize": True,
"topology_sorting_strategy": "StableRDFS",
}
return config


def get_compiler_backend() -> str:
if hasattr(torch, "hpu") and torch.hpu.is_available():
return "hpu_backend"

if hasattr(torch, "npu") and torch.npu.is_available():
import torchair
try:
import torchair
import torchair.ge_concrete_graph.ge_converter.experimental.patch_for_hcom_allreduce
from torchair.configs.compiler_config import CompilerConfig
except ImportError as e:
raise ImportError(
"NPU detected, but the torchair package is not installed. "
"Please install torchair to enable torch.compile support on NPU."
) from e
compiler_config = CompilerConfig()
predefined_config = get_npu_compiler_config()
for k, v in predefined_config.items():
setattr(compiler_config.experimental_config, k, v)

config = torchair.CompilerConfig()
npu_backend = torchair.get_npu_backend(compiler_config=config)
npu_backend = torchair.get_npu_backend(compiler_config=compiler_config)
return npu_backend

return "inductor"
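For context, the value returned here is meant to be handed to torch.compile: on CUDA it is the string "inductor", while on Ascend it is the torchair backend object built above, and torch.compile accepts either form for its backend argument. A hypothetical usage sketch, illustrative only and not taken from this PR:

import torch

from sglang.srt.utils import get_compiler_backend

def silu_mul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return torch.nn.functional.silu(x) * y

# get_compiler_backend() returns "inductor" on CUDA and a torchair backend on NPU.
compiled = torch.compile(silu_mul, backend=get_compiler_backend())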