Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 14 additions & 10 deletions flashinfer/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,17 +24,19 @@
from .utils import (
_get_cache_buf,
device_support_pdl,
get_default_generators,
register_custom_op,
register_fake_op,
)


def get_seed_and_offset(
increment: int, generator: Optional[torch.Generator] = None
increment: int,
generator: Optional[torch.Generator] = None,
device: Optional[torch.device] = None,
) -> Tuple[int, int]:
if generator is None:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
generator = torch.Generator(device=device)
generator = get_default_generators(device)
    # add mutex if multi-threading needed
state = generator.get_state()
seed, offset = state.view(torch.int64)
Expand Down Expand Up @@ -100,7 +102,9 @@ def sampling_from_logits(
out_dtype = indices.dtype if indices is not None else torch.int32
samples = torch.empty(batch_size, dtype=out_dtype, device=device)
if seed is None or offset is None:
seed, offset = get_seed_and_offset(batch_size * logits.size(1), generator)
seed, offset = get_seed_and_offset(
batch_size * logits.size(1), generator, device
)
module.sampling_from_logits(
logits,
samples,
Expand Down Expand Up @@ -139,7 +143,7 @@ def sampling_from_probs(
out_dtype = indices.dtype if indices is not None else torch.int32
samples = torch.empty(batch_size, dtype=out_dtype, device=device)
if seed is None or offset is None:
seed, offset = get_seed_and_offset(batch_size, generator)
seed, offset = get_seed_and_offset(batch_size, generator, device)
module.sampling_from_probs(
probs,
samples,
Expand Down Expand Up @@ -185,7 +189,7 @@ def top_p_sampling_from_probs(
out_dtype = indices.dtype if indices is not None else torch.int32
samples = torch.empty(batch_size, dtype=out_dtype, device=device)
if seed is None or offset is None:
seed, offset = get_seed_and_offset(batch_size * 32, generator)
seed, offset = get_seed_and_offset(batch_size * 32, generator, device)
module.top_p_sampling_from_probs(
probs,
samples,
Expand Down Expand Up @@ -232,7 +236,7 @@ def top_k_sampling_from_probs(
out_dtype = indices.dtype if indices is not None else torch.int32
samples = torch.empty(batch_size, dtype=out_dtype, device=device)
if seed is None or offset is None:
seed, offset = get_seed_and_offset(batch_size * 32, generator)
seed, offset = get_seed_and_offset(batch_size * 32, generator, device)
module.top_k_sampling_from_probs(
probs,
samples,
Expand Down Expand Up @@ -281,7 +285,7 @@ def min_p_sampling_from_probs(
out_dtype = indices.dtype if indices is not None else torch.int32
samples = torch.empty(batch_size, dtype=out_dtype, device=device)
if seed is None or offset is None:
seed, offset = get_seed_and_offset(batch_size, generator)
seed, offset = get_seed_and_offset(batch_size, generator, device)
module.min_p_sampling_from_probs(
probs,
samples,
Expand Down Expand Up @@ -319,7 +323,7 @@ def top_k_top_p_sampling_from_probs(
out_dtype = indices.dtype if indices is not None else torch.int32
samples = torch.empty(batch_size, dtype=out_dtype, device=device)
if seed is None or offset is None:
seed, offset = get_seed_and_offset(batch_size * 32, generator)
seed, offset = get_seed_and_offset(batch_size * 32, generator, device)
module.top_k_top_p_sampling_from_probs(
probs,
samples,
Expand Down Expand Up @@ -480,7 +484,7 @@ def chain_speculative_sampling(
output_token_ids = torch.empty((b, n + 1), dtype=torch.int32, device=device)
if seed is None or offset is None:
seed, offset = get_seed_and_offset(
draft_probs.size(0) * (draft_probs.size(1) + 1), generator
draft_probs.size(0) * (draft_probs.size(1) + 1), generator, device
)
module.chain_speculative_sampling(
draft_probs,
Expand Down
6 changes: 6 additions & 0 deletions flashinfer/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1182,3 +1182,9 @@ def wrapper(*args, **kwargs):
return wrapper

return decorator


@functools.cache
def get_default_generators(device: torch.device):
    """Return the default CUDA RNG generator for ``device``.

    Results are cached per-device so repeated lookups are cheap. The device
    must be a CUDA device; a bare ``torch.device("cuda")`` (no index) is
    normalized to the current CUDA device before indexing.

    Args:
        device: The CUDA device whose default generator is requested.

    Returns:
        The ``torch.Generator`` from ``torch.cuda.default_generators`` for
        the resolved device index.

    Raises:
        ValueError: If ``device`` is not a CUDA device.
    """
    if device.type != "cuda":
        raise ValueError(f"Device must be a CUDA device, got {device.type}")
    if device.index is None:
        # torch.device("cuda") has index None; resolve it to the current
        # device so the default_generators lookup below does not fail.
        device = torch.device("cuda", torch.cuda.current_device())
    # default_generators is only populated after CUDA is initialized.
    torch.cuda.init()
    return torch.cuda.default_generators[device.index]
Comment on lines +1187 to +1190
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Add device validation to prevent cryptic runtime errors.

The function lacks validation for the device parameter, which can lead to confusing errors:

  • If device.type != "cuda", accessing torch.cuda.default_generators will fail
  • If device.index is None (e.g., torch.device("cuda")), the indexing operation will fail
  • If device.index exceeds the number of CUDA devices, it will raise an IndexError
πŸ›‘οΈ Proposed fix with device validation
 @functools.cache
 def get_default_generators(device: torch.device):
+    if device.type != "cuda":
+        raise ValueError(f"Device must be a CUDA device, got {device.type}")
+    if device.index is None:
+        device = torch.device("cuda", torch.cuda.current_device())
     torch.cuda.init()
     return torch.cuda.default_generators[device.index]

Committable suggestion skipped: line range outside the PR's diff.

πŸ€– Prompt for AI Agents
In @flashinfer/utils.py around lines 1187 - 1190, Validate the incoming
torch.device in get_default_generators: ensure device.type == "cuda", that
device.index is not None, call torch.cuda.init() and then confirm device.index
is within range (0 <= device.index < torch.cuda.device_count()); if any check
fails raise a ValueError with a clear message describing the problem (e.g.
non-CUDA device, unspecified index, or index out of range), otherwise return
torch.cuda.default_generators[device.index]; keep the @functools.cache decorator
and reference get_default_generators and torch.cuda.default_generators in your
changes.