From 9bb2e0453e0a28e111d5a29780544768975a5500 Mon Sep 17 00:00:00 2001 From: Yaxing Cai Date: Tue, 6 Jan 2026 07:48:59 +0000 Subject: [PATCH 1/2] bugfix: set seed for default torch generator --- flashinfer/sampling.py | 1 + 1 file changed, 1 insertion(+) diff --git a/flashinfer/sampling.py b/flashinfer/sampling.py index d31e5889ea..aca2d811e5 100644 --- a/flashinfer/sampling.py +++ b/flashinfer/sampling.py @@ -35,6 +35,7 @@ def get_seed_and_offset( if generator is None: device = torch.device("cuda" if torch.cuda.is_available() else "cpu") generator = torch.Generator(device=device) + generator.seed() # add mutex if multi-threading needed state = generator.get_state() seed, offset = state.view(torch.int64) From b8c9e71dac53bedfa6dada081e9e08e4b737aae9 Mon Sep 17 00:00:00 2001 From: Yaxing Cai Date: Wed, 7 Jan 2026 08:13:03 +0000 Subject: [PATCH 2/2] upd --- flashinfer/sampling.py | 25 ++++++++++++++----------- flashinfer/utils.py | 6 ++++++ 2 files changed, 20 insertions(+), 11 deletions(-) diff --git a/flashinfer/sampling.py b/flashinfer/sampling.py index aca2d811e5..02979574af 100644 --- a/flashinfer/sampling.py +++ b/flashinfer/sampling.py @@ -24,18 +24,19 @@ from .utils import ( _get_cache_buf, device_support_pdl, + get_default_generators, register_custom_op, register_fake_op, ) def get_seed_and_offset( - increment: int, generator: Optional[torch.Generator] = None + increment: int, + generator: Optional[torch.Generator] = None, + device: Optional[torch.device] = None, ) -> Tuple[int, int]: if generator is None: - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - generator = torch.Generator(device=device) - generator.seed() + generator = get_default_generators(device) # add mutex if multi-threading needed state = generator.get_state() seed, offset = state.view(torch.int64) @@ -101,7 +102,9 @@ def sampling_from_logits( out_dtype = indices.dtype if indices is not None else torch.int32 samples = torch.empty(batch_size, dtype=out_dtype,
device=device) if seed is None or offset is None: - seed, offset = get_seed_and_offset(batch_size * logits.size(1), generator) + seed, offset = get_seed_and_offset( + batch_size * logits.size(1), generator, device + ) module.sampling_from_logits( logits, samples, @@ -140,7 +143,7 @@ def sampling_from_probs( out_dtype = indices.dtype if indices is not None else torch.int32 samples = torch.empty(batch_size, dtype=out_dtype, device=device) if seed is None or offset is None: - seed, offset = get_seed_and_offset(batch_size, generator) + seed, offset = get_seed_and_offset(batch_size, generator, device) module.sampling_from_probs( probs, samples, @@ -186,7 +189,7 @@ def top_p_sampling_from_probs( out_dtype = indices.dtype if indices is not None else torch.int32 samples = torch.empty(batch_size, dtype=out_dtype, device=device) if seed is None or offset is None: - seed, offset = get_seed_and_offset(batch_size * 32, generator) + seed, offset = get_seed_and_offset(batch_size * 32, generator, device) module.top_p_sampling_from_probs( probs, samples, @@ -233,7 +236,7 @@ def top_k_sampling_from_probs( out_dtype = indices.dtype if indices is not None else torch.int32 samples = torch.empty(batch_size, dtype=out_dtype, device=device) if seed is None or offset is None: - seed, offset = get_seed_and_offset(batch_size * 32, generator) + seed, offset = get_seed_and_offset(batch_size * 32, generator, device) module.top_k_sampling_from_probs( probs, samples, @@ -282,7 +285,7 @@ def min_p_sampling_from_probs( out_dtype = indices.dtype if indices is not None else torch.int32 samples = torch.empty(batch_size, dtype=out_dtype, device=device) if seed is None or offset is None: - seed, offset = get_seed_and_offset(batch_size, generator) + seed, offset = get_seed_and_offset(batch_size, generator, device) module.min_p_sampling_from_probs( probs, samples, @@ -320,7 +323,7 @@ def top_k_top_p_sampling_from_probs( out_dtype = indices.dtype if indices is not None else torch.int32 samples = 
torch.empty(batch_size, dtype=out_dtype, device=device) if seed is None or offset is None: - seed, offset = get_seed_and_offset(batch_size * 32, generator) + seed, offset = get_seed_and_offset(batch_size * 32, generator, device) module.top_k_top_p_sampling_from_probs( probs, samples, @@ -481,7 +484,7 @@ def chain_speculative_sampling( output_token_ids = torch.empty((b, n + 1), dtype=torch.int32, device=device) if seed is None or offset is None: seed, offset = get_seed_and_offset( - draft_probs.size(0) * (draft_probs.size(1) + 1), generator + draft_probs.size(0) * (draft_probs.size(1) + 1), generator, device ) module.chain_speculative_sampling( draft_probs, diff --git a/flashinfer/utils.py b/flashinfer/utils.py index 35861b2507..e2e820f221 100644 --- a/flashinfer/utils.py +++ b/flashinfer/utils.py @@ -1182,3 +1182,9 @@ def wrapper(*args, **kwargs): return wrapper return decorator + + +@functools.cache +def get_default_generators(device: torch.device): + torch.cuda.init() + return torch.cuda.default_generators[device.index]