7 changes: 3 additions & 4 deletions python/sglang/srt/layers/sampler.py
@@ -100,17 +100,16 @@ def forward(
probs, sampling_info.min_ps
)
else:
# The NaN check throws an exception when it triggers, so only enable it when crash_on_warnings is True
check_nan = self.use_nan_detection and crash_on_warnings()
@yubofredwang (Contributor, Author) commented on Apr 19, 2025:

Adding another comment for future reference: we already check the logits for NaN before they are converted into probs, so the chance of probs containing NaN is quite low. Even when probs did contain NaN, sampling was still able to generate tokens.
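
A minimal sketch of the kind of logits-stage check described above (the function and variable names are illustrative, not the actual sampler code):

```python
import torch

def check_logits_for_nan(logits: torch.Tensor, crash_on_warnings: bool = False) -> None:
    """Illustrative only: flag NaN at the logits stage, before softmax."""
    if torch.isnan(logits).any():
        if crash_on_warnings:
            raise ValueError("NaN detected in logits before sampling")
        print("Warning: NaN detected in logits before sampling")

# probs derived from already-screened logits rarely need a second NaN check.
logits = torch.randn(2, 8)
check_logits_for_nan(logits)
probs = torch.softmax(logits, dim=-1)
```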

batch_next_token_ids = top_k_top_p_sampling_from_probs(
probs,
sampling_info.top_ks,
sampling_info.top_ps,
filter_apply_order="joint",
check_nan=check_nan,
)

if self.use_nan_detection:
@yubofredwang (Contributor, Author):

Adding a comment here as well: we can remove this check because the new FlashInfer sampling implementation guarantees that valid tokens are generated. See PR: flashinfer-ai/flashinfer#912
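
A hedged sketch of the contract change that makes this check removable; the old tuple-returning shape is inferred from the `Tuple[torch.Tensor, torch.Tensor]` annotations removed in `sgl-kernel/python/sgl_kernel/sampling.py` below, and both helpers are mocks rather than FlashInfer code:

```python
import torch

def old_style_sample(probs: torch.Tensor):
    # Mock of the pre-#912 contract: sampling could fail, so a per-request
    # success mask accompanied the token ids and callers had to branch on it.
    samples = torch.multinomial(probs, num_samples=1).squeeze(-1)
    success = torch.ones(probs.shape[0], dtype=torch.bool)
    return samples, success

def new_style_sample(probs: torch.Tensor) -> torch.Tensor:
    # Mock of the post-#912 contract: a valid token per request is guaranteed,
    # so only the sampled ids are returned.
    return torch.multinomial(probs, num_samples=1).squeeze(-1)

probs = torch.softmax(torch.randn(3, 16), dim=-1)
old_ids, success = old_style_sample(probs)  # old: check `success`, zero out failures
new_ids = new_style_sample(probs)           # new: use the result directly
```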

logger.warning("Detected errors during sampling!")
batch_next_token_ids = torch.zeros_like(batch_next_token_ids)

elif global_server_args_dict["sampling_backend"] == "pytorch":
# A slower fallback implementation with torch native operations.
batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch(
10 changes: 5 additions & 5 deletions sgl-kernel/python/sgl_kernel/sampling.py
@@ -1,4 +1,4 @@
from typing import Optional, Tuple, Union
from typing import Optional, Union

import torch
from sgl_kernel.utils import _to_tensor_scalar_tuple, get_cuda_stream
@@ -109,7 +109,7 @@ def _top_p_sampling_from_probs_internal(
top_p_val: float,
deterministic: bool,
generator: Optional[torch.Generator],
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> torch.Tensor:
with probs.device as device:
probs = probs.float()
maybe_top_p_arr = (
@@ -135,7 +135,7 @@ def top_p_sampling_from_probs(
deterministic: bool = True,
generator: Optional[torch.Generator] = None,
check_nan: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> torch.Tensor:
r"""Adapted from https://github.com/flashinfer-ai/flashinfer/flashinfer/sampling.py
Fused GPU kernel for top-p sampling (nucleus sampling) from probabilities,
this operator implements GPU-based rejection sampling without explicit sorting.
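
To make "rejection sampling without explicit sorting" concrete, here is a hedged, CPU-only sketch of the idea for a single row of probabilities; it is not the CUDA kernel, and the function name is made up:

```python
import torch

def top_p_rejection_sample(probs: torch.Tensor, top_p: float, max_rounds: int = 32) -> torch.Tensor:
    """Illustrative sorting-free top-p sampling for one probability row."""
    masked = probs.clone()
    for _ in range(max_rounds):
        idx = torch.multinomial(masked / masked.sum(), num_samples=1).item()
        pivot = probs[idx]
        # Mass of strictly higher-probability tokens; if it already reaches
        # top_p, the candidate lies outside the nucleus and is rejected.
        mass_above = probs[probs > pivot].sum()
        if mass_above < top_p:
            return torch.tensor(idx)
        # Reject: restrict future draws to tokens above the pivot and retry.
        masked = torch.where(probs > pivot, probs, torch.zeros_like(probs))
    return torch.tensor(idx)  # fall back to the last candidate

probs = torch.softmax(torch.randn(100), dim=-1)
token = top_p_rejection_sample(probs, top_p=0.9)
```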
@@ -194,7 +194,7 @@ def _top_k_top_p_sampling_from_probs_internal(
top_p_val: float,
deterministic: bool,
generator: Optional[torch.Generator],
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> torch.Tensor:
with probs.device as device:
probs = probs.float()
maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None
@@ -225,7 +225,7 @@ def top_k_top_p_sampling_from_probs(
deterministic: bool = True,
generator: Optional[torch.Generator] = None,
check_nan: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> torch.Tensor:
r"""Adapted from https://github.com/flashinfer-ai/flashinfer/flashinfer/sampling.py
Fused GPU kernel for top-k and top-p sampling from probabilities,

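
Taken together, the signature changes above mean callers now get back only the sampled token ids. A brief usage sketch, assuming a CUDA device and that the wrapper accepts the same arguments as the `sampler.py` call shown earlier (the tensor values are made up):

```python
import torch
from sgl_kernel.sampling import top_k_top_p_sampling_from_probs

probs = torch.softmax(torch.randn(4, 32000, device="cuda"), dim=-1)
top_ks = torch.full((4,), 50, dtype=torch.int32, device="cuda")
top_ps = torch.full((4,), 0.9, device="cuda")

# Returns a single tensor of token ids; before this change the
# annotations advertised a (samples, success) tuple.
next_token_ids = top_k_top_p_sampling_from_probs(
    probs,
    top_ks,
    top_ps,
    filter_apply_order="joint",
    check_nan=False,
)
print(next_token_ids.shape)  # expected: torch.Size([4])
```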