Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions python/sglang/srt/layers/communicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
is_hip,
is_sm90_supported,
is_sm100_supported,
prepare_weight_cache,
)

_is_flashinfer_available = is_flashinfer_available()
Expand Down Expand Up @@ -275,7 +276,11 @@ def prepare_mlp(
hidden_states: torch.Tensor,
residual: torch.Tensor,
forward_batch: ForwardBatch,
cache=None,
):
if cache is not None:
self._context.cache = cache

return self._communicate_with_all_reduce_and_layer_norm_fn(
hidden_states=hidden_states,
residual=residual,
Expand Down Expand Up @@ -349,6 +354,7 @@ class CommunicateContext:
attn_tp_size: int
attn_dp_size: int
tp_size: int
cache = None

def is_same_group_size(self, a: ScatterMode, b: ScatterMode):
    """Return True when scatter modes *a* and *b* map to process groups of equal size.

    Looks up both modes in ``self.process_group_sizes`` (a mapping from
    ScatterMode to group size) and compares the recorded sizes; raises
    ``KeyError`` if either mode is not registered in the mapping.
    """
    return self.process_group_sizes[a] == self.process_group_sizes[b]
Expand Down Expand Up @@ -533,6 +539,8 @@ def _gather_hidden_states_and_residual(
)
else:
hidden_states = tensor_model_parallel_all_reduce(hidden_states)
if context.cache is not None:
_ = prepare_weight_cache(hidden_states, context.cache)
hidden_states, residual = layernorm(hidden_states, residual)
return hidden_states, residual

Expand Down
2 changes: 2 additions & 0 deletions python/sglang/srt/layers/quantization/w8a8_int8.py
Original file line number Diff line number Diff line change
Expand Up @@ -638,6 +638,7 @@ def process_weights_after_loading(self, layer):
layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
layer.weight_scale.data = torch.flatten(layer.weight_scale.data)
layer.weight_offset.data = torch.flatten(layer.weight_offset.data)
layer.weight.data = torch_npu.npu_format_cast(layer.weight.data, 29)


class NPU_W8A8LinearMethodMTImpl:
Expand Down Expand Up @@ -830,6 +831,7 @@ def process_weights_after_loading(self, layer):
layer.weight_scale.data = layer.weight_scale.data.flatten()
layer.weight_scale_fp32 = layer.weight_scale.data.to(torch.float32)
layer.weight_offset.data = layer.weight_offset.data.flatten()
layer.weight.data = torch_npu.npu_format_cast(layer.weight.data, 29)


class NPU_W8A8DynamicLinearMethod(LinearMethodBase):
Expand Down
7 changes: 7 additions & 0 deletions python/sglang/srt/model_executor/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,13 @@ def add_mla_attention_backend(backend_name):
logger = logging.getLogger(__name__)


if _is_npu:
import torch_npu

torch.npu.config.allow_internal_format = True
torch_npu.npu.set_compile_mode(jit_compile=False)


class RankZeroFilter(logging.Filter):
"""Filter that only allows INFO level logs from rank 0, but allows all other levels from any rank."""

Expand Down
2 changes: 2 additions & 0 deletions python/sglang/srt/model_executor/npu_graph_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@
import threading
from typing import TYPE_CHECKING, Optional, Union

import numpy as np
import torch

from sglang.srt.configs.model_config import AttentionArch
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner

logger = logging.getLogger(__name__)
Expand Down
20 changes: 18 additions & 2 deletions python/sglang/srt/models/qwen3.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,19 @@
)
from sglang.srt.models.qwen2 import Qwen2MLP as Qwen3MLP
from sglang.srt.models.qwen2 import Qwen2Model
from sglang.srt.utils import add_prefix, is_cuda
from sglang.srt.utils import (
add_prefix,
get_cmo_stream,
is_cuda,
is_npu,
wait_cmo_stream,
)

Qwen3Config = None

logger = logging.getLogger(__name__)
_is_cuda = is_cuda()
_is_npu = is_npu()


class Qwen3Attention(nn.Module):
Expand Down Expand Up @@ -235,9 +242,18 @@ def forward(

# Fully Connected
hidden_states, residual = self.layer_communicator.prepare_mlp(
hidden_states, residual, forward_batch
hidden_states,
residual,
forward_batch,
cache=(
[self.mlp.gate_up_proj.weight, self.mlp.down_proj.weight]
if _is_npu
else None
),
)
hidden_states = self.mlp(hidden_states)
if _is_npu and get_cmo_stream():
wait_cmo_stream()
hidden_states, residual = self.layer_communicator.postprocess_layer(
hidden_states, residual, forward_batch
)
Expand Down
44 changes: 44 additions & 0 deletions python/sglang/srt/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,6 +513,50 @@ def make_layers(
return modules, start_layer, end_layer


cmo_stream = None


def get_cmo_stream():
    """Return the process-wide Cache Management Operation (CMO) stream.

    The stream is created lazily on first use and then reused. It is a
    dedicated side stream for prefetching matmul weights while other AIV
    or communication kernels run, so the memory-access time overlaps with
    compute/communication.
    """
    global cmo_stream
    if cmo_stream is None:
        device_module = torch.get_device_module()
        cmo_stream = device_module.Stream()
    return cmo_stream


def prepare_weight_cache(handle, cache):
    """Asynchronously prefetch matmul weights on the CMO side stream.

    Args:
        handle: Dependency tensor passed to ``torch_npu.npu_prefetch``; the
            prefetch is ordered after work producing this tensor.
        cache: A single weight tensor, or a list of weight tensors, to
            prefetch into cache ahead of their matmul.

    The side stream first waits on the current NPU stream, then issues one
    prefetch per weight. Callers later synchronize via ``wait_cmo_stream``.
    """
    import torch_npu

    # Deliberately large cap (1 GB) so each entire weight is prefetched.
    NPU_PREFETCH_MAX_SIZE_BYTES = 1000000000

    side_stream = get_cmo_stream()
    side_stream.wait_stream(torch.npu.current_stream())

    # Normalize to a list so a single tensor and a list take the same path.
    weights = cache if isinstance(cache, list) else [cache]
    with torch.npu.stream(side_stream):
        for weight in weights:
            torch_npu.npu_prefetch(
                weight,
                handle,
                NPU_PREFETCH_MAX_SIZE_BYTES,
            )


def wait_cmo_stream():
    """Block the current compute stream until all CMO prefetches are issued.

    Inserts a wait on the current device stream so subsequent kernels run
    only after work queued on the CMO side stream has completed.
    """
    torch.get_device_module().current_stream().wait_stream(get_cmo_stream())


def set_random_seed(seed: int) -> None:
"""Set the random seed for all libraries."""
random.seed(seed)
Expand Down
Loading