diff --git a/tensorrt_llm/_torch/pyexecutor/sampling_utils_flashinfer.py b/tensorrt_llm/_torch/pyexecutor/sampling_utils_flashinfer.py
index f8ce56a1672..a45efc8d8e5 100644
--- a/tensorrt_llm/_torch/pyexecutor/sampling_utils_flashinfer.py
+++ b/tensorrt_llm/_torch/pyexecutor/sampling_utils_flashinfer.py
@@ -383,22 +383,17 @@ def __init__(self, top_k: torch.Tensor, top_p: torch.Tensor, temperature: torch.Tensor):
         def from_strategies(
             cls, strategies: list[Strategy], cuda_device: torch.device
         ) -> "_StrategyImpls.TopKTopPSampleOnly":
-            assert all(strat[0] in ["top_k_top_p", "top_k"] for strat in strategies)
-            narrowed_strats = cast(list[TopKTopP | TopK], strategies)
-            top_k_list = []
-            top_p_list = []
-            temperature_list = []
-            for strat in narrowed_strats:
-                top_k_list.append(strat[1])
-                if strat[0] == "top_k_top_p":
-                    top_p_list.append(strat[2])
-                    temperature_list.append(strat[3])
-                else:
-                    top_p_list.append(1.0)
-                    temperature_list.append(strat[2])
-            top_k = cls._make_tensor(top_k_list, torch.int32, cuda_device)
-            top_p = cls._make_tensor(top_p_list, torch.float32, cuda_device)
-            temperature = cls._make_tensor(temperature_list, torch.float32, cuda_device)
+            assert all(strat[0] == "top_k_top_p" for strat in strategies)
+            narrowed_strats = cast(list[TopKTopP], strategies)
+            top_k = cls._make_tensor(
+                [strat[1] for strat in narrowed_strats], torch.int32, cuda_device
+            )
+            top_p = cls._make_tensor(
+                [strat[2] for strat in narrowed_strats], torch.float32, cuda_device
+            )
+            temperature = cls._make_tensor(
+                [strat[3] for strat in narrowed_strats], torch.float32, cuda_device
+            )
             return cls(top_k, top_p, temperature)

         @override
@@ -427,6 +422,50 @@ def sample(
             generator=generator,
         ), None

+    class TopKSampleOnly(StrategyImplSampleOnly):
+        def __init__(self, top_k: torch.Tensor, temperature: torch.Tensor):
+            self._top_k = top_k
+            self._temperature = temperature
+
+        @override
+        @classmethod
+        def from_strategies(
+            cls, strategies: list[Strategy], cuda_device: torch.device
+        ) -> "_StrategyImpls.TopKSampleOnly":
+            assert all(strat[0] == "top_k" for strat in strategies)
+            narrowed_strats = cast(list[TopK], strategies)
+            top_k = cls._make_tensor(
+                [strat[1] for strat in narrowed_strats], torch.int32, cuda_device
+            )
+            temperature = cls._make_tensor(
+                [strat[2] for strat in narrowed_strats], torch.float32, cuda_device
+            )
+            return cls(top_k, temperature)
+
+        @override
+        def sample(
+            self,
+            logits: torch.Tensor,
+            *,
+            group_logit_indices: Optional[torch.Tensor] = None,
+            generator: Optional[torch.Generator] = None,
+        ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
+            probs = self._prepare_probs_with_temperature(
+                logits, group_logit_indices, self._temperature
+            )
+            return flashinfer.sampling.top_k_sampling_from_probs(
+                probs,
+                top_k=self._top_k,
+                # NB: Leveraging 'indices' would require applying temperature+softmax before batching,
+                # because 'flashinfer.sampling.softmax' has no 'indices' argument; but that would
+                # unnecessarily compute softmax even in situations allowing
+                # flashinfer.sampling...._sampling_from_logits.
+                # indices=group_logit_indices,
+                deterministic=True,
+                check_nan=self._flashinfer_check_nans(probs),
+                generator=generator,
+            ), None
+
     class TopPSampleOnly(StrategyImplSampleOnly):
         def __init__(self, top_p: torch.Tensor, temperature: torch.Tensor):
             self._top_p = top_p
@@ -540,10 +579,9 @@ def strategy_grouping_key(strategy: Strategy, return_probs: bool) -> STRATEGY_KE
     match strategy:
         case ("top_p", _, _):
             return _StrategyImpls.TopPSampleOnly
-        case ("top_k_top_p", _, _, _) | ("top_k", _, _):
-            # NB: There is no TopKSampleOnly, because FlashInfer only provides
-            # top_k_sampling_from_probs (not top_k_sampling_from_logits),
-            # which is likely slower than top_k_top_p_sampling_from_logits.
+        case ("top_k", _, _):
+            return _StrategyImpls.TopKSampleOnly
+        case ("top_k_top_p", _, _, _):
             return _StrategyImpls.TopKTopPSampleOnly
         case ("temperature", _):
             return _StrategyImpls.TemperatureOnlySampleOnly
diff --git a/tests/unittest/_torch/sampler/test_torch_sampler.py b/tests/unittest/_torch/sampler/test_torch_sampler.py
index b3db660ed17..cc2e52904ce 100644
--- a/tests/unittest/_torch/sampler/test_torch_sampler.py
+++ b/tests/unittest/_torch/sampler/test_torch_sampler.py
@@ -1601,6 +1601,37 @@ def _mock_flashinfer_from_logits(

         patch_ctx.setattr(flashinfer.sampling, "sampling_from_logits", _mock_flashinfer_from_logits)

+        def _mock_flashinfer_top_k(
+            probs: torch.Tensor,
+            *,
+            top_k: torch.Tensor,
+            deterministic: bool,
+            check_nan: bool,
+            generator: torch.Generator,
+        ) -> torch.Tensor:
+            assert deterministic
+            assert not check_nan, "check_nan syncs"
+            assert generator is sampler.get_generator(probs.device)
+            nonlocal mock_sampling_log
+            new_entries = [
+                TestBatchedSampling._MockSamplingLogEntry(
+                    probs=probs[row_idx],
+                    sampling_params=TestBatchedSampling._TorchUtilsSamplingParams(
+                        top_k=top_k[row_idx],
+                        top_p=None,
+                        temperature=None,
+                    ),
+                )
+                for row_idx in range(probs.size(0))
+            ]
+            mock_tokens = torch.arange(
+                len(mock_sampling_log), len(mock_sampling_log) + len(new_entries)
+            )
+            mock_sampling_log += new_entries
+            return mock_tokens
+
+        patch_ctx.setattr(flashinfer.sampling, "top_k_sampling_from_probs", _mock_flashinfer_top_k)
+
         def _mock_flashinfer_top_p(
             probs: torch.Tensor,
             *,
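
Note (reviewer sketch, not part of the patch): the new TopKSampleOnly path amounts to temperature-scaled softmax followed by per-row top-k restricted sampling, which is why strategy_grouping_key can now dispatch ("top_k", k, temp) directly instead of padding top_p=1.0 into the TopKTopP path. Below is a rough pure-PyTorch reference of the intended semantics, under assumed shapes of [batch, vocab] for logits and per-row top_k/temperature vectors. The function name is hypothetical, and this is not the FlashInfer kernel itself: tie handling at the k-th probability and determinism guarantees may differ.

    from typing import Optional

    import torch

    def top_k_sample_reference(
        logits: torch.Tensor,       # [batch, vocab]
        top_k: torch.Tensor,        # [batch], per-row k (integer dtype)
        temperature: torch.Tensor,  # [batch], per-row temperature
        generator: Optional[torch.Generator] = None,
    ) -> torch.Tensor:
        # Temperature scaling + softmax, i.e. the probs that
        # _prepare_probs_with_temperature hands to
        # flashinfer.sampling.top_k_sampling_from_probs in the patch.
        probs = torch.softmax(logits / temperature.unsqueeze(-1), dim=-1)
        # Per-row top-k restriction: keep probabilities >= the k-th largest in each row.
        sorted_probs, _ = probs.sort(dim=-1, descending=True)
        kth = sorted_probs.gather(-1, (top_k.long() - 1).clamp(min=0).unsqueeze(-1))
        masked = torch.where(probs >= kth, probs, torch.zeros_like(probs))
        # Renormalize the surviving mass and draw one token per row.
        masked = masked / masked.sum(dim=-1, keepdim=True)
        return torch.multinomial(masked, num_samples=1, generator=generator).squeeze(-1)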