diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index eaecdc54bcf4..b4d00946150b 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -63,7 +63,7 @@ is_npu, support_triton, ) -from sglang.srt.utils.common import ceil_align +from sglang.srt.utils.common import ceil_align, is_pin_memory_available if TYPE_CHECKING: from sglang.srt.layers.attention.base_attn_backend import AttentionBackend @@ -480,6 +480,7 @@ def init_new( rids=[req.rid for req in batch.reqs], ) device = model_runner.device + _pin = is_pin_memory_available(device) if batch.extend_input_logprob_token_ids is not None: ret.extend_input_logprob_token_ids_gpu = ( @@ -488,9 +489,9 @@ def init_new( num_tokens = len(batch.input_ids) if batch.input_ids is not None else 0 if enable_num_token_non_padded(model_runner.server_args): - ret.num_token_non_padded = torch.tensor(num_tokens, dtype=torch.int32).to( - device, non_blocking=True - ) + ret.num_token_non_padded = torch.tensor( + num_tokens, dtype=torch.int32, pin_memory=_pin + ).to(device, non_blocking=True) ret.num_token_non_padded_cpu = num_tokens # For MLP sync @@ -510,15 +511,18 @@ def init_new( ret.original_global_num_tokens_cpu = batch.global_num_tokens ret.global_num_tokens_cpu = global_num_tokens ret.global_num_tokens_gpu = torch.tensor( - global_num_tokens, dtype=torch.int64 + global_num_tokens, dtype=torch.int64, pin_memory=_pin ).to(device, non_blocking=True) ret.global_num_tokens_for_logprob_cpu = global_num_tokens_for_logprob ret.global_num_tokens_for_logprob_gpu = torch.tensor( - global_num_tokens_for_logprob, dtype=torch.int64 + global_num_tokens_for_logprob, dtype=torch.int64, pin_memory=_pin ).to(device, non_blocking=True) if ret.forward_mode.is_idle(): + if _is_npu: + # This synchronize is necessary to prevent the system from hanging on npu. 
+ torch.npu.synchronize() ret.positions = torch.empty((0,), dtype=torch.int64, device=device) return ret @@ -534,6 +538,7 @@ def init_new( for i in range(block_offset, block_offset + block_size) ], dtype=positions_dtype, + pin_memory=_pin, ).to(device, non_blocking=True) elif ( ret.spec_info is not None @@ -549,10 +554,10 @@ def init_new( assert isinstance(batch.extend_seq_lens, list) assert isinstance(batch.extend_prefix_lens, list) ret.extend_seq_lens = torch.tensor( - batch.extend_seq_lens, dtype=torch.int32 + batch.extend_seq_lens, dtype=torch.int32, pin_memory=_pin ).to(device, non_blocking=True) ret.extend_prefix_lens = torch.tensor( - batch.extend_prefix_lens, dtype=torch.int32 + batch.extend_prefix_lens, dtype=torch.int32, pin_memory=_pin ).to(device, non_blocking=True) ret.extend_num_tokens = batch.extend_num_tokens positions, ret.extend_start_loc = compute_position( @@ -755,6 +760,7 @@ def _compute_mrope_positions( # batch_size * [3 * seq_len] batch_size = self.seq_lens_cpu.shape[0] mrope_positions_list = [[]] * batch_size + _pin = is_pin_memory_available(model_runner.device) for batch_idx in range(batch_size): mm_input = batch.multimodal_inputs[batch_idx] if self.forward_mode.is_decode(): @@ -806,10 +812,20 @@ def _compute_mrope_positions( ) mrope_positions_list[batch_idx] = mrope_positions - self.mrope_positions = torch.cat( - [pos for pos in mrope_positions_list], - dim=1, - ).to(dtype=torch.int64, device=model_runner.device, non_blocking=True) + if _pin: + self.mrope_positions = ( + torch.cat( + [pos for pos in mrope_positions_list], + dim=1, + ) + .pin_memory() + .to(dtype=torch.int64, device=model_runner.device, non_blocking=True) + ) + else: + self.mrope_positions = torch.cat( + [pos for pos in mrope_positions_list], + dim=1, + ).to(dtype=torch.int64, device=model_runner.device, non_blocking=True) def _pad_tensor_to_size(self, tensor: torch.Tensor, size: int, *, value: int = 0): if value == 0: diff --git a/python/sglang/srt/utils/common.py 
# Probe accelerator availability once at import time; these calls are
# comparatively expensive and the answer cannot change within a process.
_is_cuda = is_cuda()
_is_npu = is_npu()


def is_pin_memory_available(device=None) -> bool:
    """Return whether pinned (page-locked) host memory should be used for *device*.

    Pinned host buffers are what make ``.to(device, non_blocking=True)`` a truly
    asynchronous H2D copy; pinning for a CPU (or unrecognized) target is wasted
    work, so we gate on the resolved device type.

    Args:
        device: ``None``, a device string (``"cuda"``, ``"cuda:0"``, ``"npu"``,
            ``"cpu"``, ...), or a ``torch.device``. ``None`` means "any
            supported accelerator present in this process".

    Returns:
        True iff the device type is a supported accelerator (CUDA or NPU)
        that is actually available; False for CPU and unknown backends.
    """
    if device is None:
        return _is_cuda or _is_npu
    # Normalize "cuda:0" / torch.device("cuda", 0) (whose str() is "cuda:0")
    # down to the bare device type. A plain equality check against "cuda"
    # would miss every indexed device and silently disable pinning.
    device_type = str(device).split(":", 1)[0]
    if device_type == "cuda":
        return _is_cuda
    if device_type == "npu":
        return _is_npu
    return False