Skip to content
40 changes: 28 additions & 12 deletions python/sglang/srt/model_executor/forward_batch_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@
is_npu,
support_triton,
)
from sglang.srt.utils.common import ceil_align
from sglang.srt.utils.common import ceil_align, is_pin_memory_available
Comment thread
litmei marked this conversation as resolved.

if TYPE_CHECKING:
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
Expand Down Expand Up @@ -480,6 +480,7 @@ def init_new(
rids=[req.rid for req in batch.reqs],
)
device = model_runner.device
_pin = is_pin_memory_available(device)

if batch.extend_input_logprob_token_ids is not None:
ret.extend_input_logprob_token_ids_gpu = (
Expand All @@ -488,9 +489,9 @@ def init_new(

num_tokens = len(batch.input_ids) if batch.input_ids is not None else 0
if enable_num_token_non_padded(model_runner.server_args):
ret.num_token_non_padded = torch.tensor(num_tokens, dtype=torch.int32).to(
device, non_blocking=True
)
ret.num_token_non_padded = torch.tensor(
Comment thread
litmei marked this conversation as resolved.
num_tokens, dtype=torch.int32, pin_memory=_pin
).to(device, non_blocking=True)
ret.num_token_non_padded_cpu = num_tokens

# For MLP sync
Expand All @@ -510,15 +511,18 @@ def init_new(
ret.original_global_num_tokens_cpu = batch.global_num_tokens
ret.global_num_tokens_cpu = global_num_tokens
ret.global_num_tokens_gpu = torch.tensor(
global_num_tokens, dtype=torch.int64
global_num_tokens, dtype=torch.int64, pin_memory=_pin
).to(device, non_blocking=True)

ret.global_num_tokens_for_logprob_cpu = global_num_tokens_for_logprob
ret.global_num_tokens_for_logprob_gpu = torch.tensor(
global_num_tokens_for_logprob, dtype=torch.int64
global_num_tokens_for_logprob, dtype=torch.int64, pin_memory=_pin
).to(device, non_blocking=True)

if ret.forward_mode.is_idle():
if _is_npu:
# This synchronize is necessary to prevent the system from hanging on npu.
torch.npu.synchronize()
ret.positions = torch.empty((0,), dtype=torch.int64, device=device)
return ret

Expand All @@ -534,6 +538,7 @@ def init_new(
for i in range(block_offset, block_offset + block_size)
],
dtype=positions_dtype,
pin_memory=_pin,
).to(device, non_blocking=True)
elif (
ret.spec_info is not None
Expand All @@ -549,10 +554,10 @@ def init_new(
assert isinstance(batch.extend_seq_lens, list)
assert isinstance(batch.extend_prefix_lens, list)
ret.extend_seq_lens = torch.tensor(
batch.extend_seq_lens, dtype=torch.int32
batch.extend_seq_lens, dtype=torch.int32, pin_memory=_pin
).to(device, non_blocking=True)
ret.extend_prefix_lens = torch.tensor(
batch.extend_prefix_lens, dtype=torch.int32
batch.extend_prefix_lens, dtype=torch.int32, pin_memory=_pin
).to(device, non_blocking=True)
ret.extend_num_tokens = batch.extend_num_tokens
positions, ret.extend_start_loc = compute_position(
Expand Down Expand Up @@ -755,6 +760,7 @@ def _compute_mrope_positions(
# batch_size * [3 * seq_len]
batch_size = self.seq_lens_cpu.shape[0]
mrope_positions_list = [[]] * batch_size
_pin = is_pin_memory_available(model_runner.device)
for batch_idx in range(batch_size):
mm_input = batch.multimodal_inputs[batch_idx]
if self.forward_mode.is_decode():
Expand Down Expand Up @@ -806,10 +812,20 @@ def _compute_mrope_positions(
)
mrope_positions_list[batch_idx] = mrope_positions

self.mrope_positions = torch.cat(
[pos for pos in mrope_positions_list],
dim=1,
).to(dtype=torch.int64, device=model_runner.device, non_blocking=True)
if _pin:
self.mrope_positions = (
torch.cat(
[pos for pos in mrope_positions_list],
dim=1,
)
.pin_memory()
.to(dtype=torch.int64, device=model_runner.device, non_blocking=True)
)
else:
self.mrope_positions = torch.cat(
[pos for pos in mrope_positions_list],
dim=1,
).to(dtype=torch.int64, device=model_runner.device, non_blocking=True)

def _pad_tensor_to_size(self, tensor: torch.Tensor, size: int, *, value: int = 0):
if value == 0:
Expand Down
16 changes: 11 additions & 5 deletions python/sglang/srt/utils/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -601,12 +601,18 @@ def get_available_gpu_memory(
return free_gpu_memory / (1 << 30)


# Cache the accelerator-platform probes once at import time so that
# is_pin_memory_available() below does not repeat the detection on every call.
# NOTE(review): presumably is_cuda()/is_npu() report which backend this build
# targets — confirm against their definitions elsewhere in this module.
_is_cuda = is_cuda()
_is_npu = is_npu()


def is_pin_memory_available(device=None) -> bool:
    """Return whether pinned (page-locked) host memory should be used.

    Args:
        device: Optional target device. May be ``None`` (meaning "any
            supported accelerator"), a string like ``"cuda"`` / ``"cuda:0"`` /
            ``"npu"``, or a ``torch.device``-like object whose ``str()``
            follows the same ``"<type>[:<index>]"`` convention.

    Returns:
        True if host tensors destined for ``device`` benefit from
        ``pin_memory``; False otherwise (e.g. for CPU targets or
        unsupported backends).
    """
    if device is None:
        return _is_cuda or _is_npu
    # Normalize "cuda:0" / torch.device("cuda", 0) to the bare device type;
    # comparing the full string against "cuda" would wrongly reject indexed
    # devices and silently disable pinning.
    device_type = str(device).split(":", 1)[0]
    if device_type == "cuda":
        return _is_cuda
    if device_type == "npu":
        return _is_npu
    return False


class LayerFn(Protocol):
Expand Down
Loading