78 changes: 58 additions & 20 deletions tensorrt_llm/_torch/pyexecutor/sampling_utils_flashinfer.py
@@ -383,22 +383,17 @@ def __init__(self, top_k: torch.Tensor, top_p: torch.Tensor, temperature: torch.
def from_strategies(
cls, strategies: list[Strategy], cuda_device: torch.device
) -> "_StrategyImpls.TopKTopPSampleOnly":
assert all(strat[0] in ["top_k_top_p", "top_k"] for strat in strategies)
narrowed_strats = cast(list[TopKTopP | TopK], strategies)
top_k_list = []
top_p_list = []
temperature_list = []
for strat in narrowed_strats:
top_k_list.append(strat[1])
if strat[0] == "top_k_top_p":
top_p_list.append(strat[2])
temperature_list.append(strat[3])
else:
top_p_list.append(1.0)
temperature_list.append(strat[2])
top_k = cls._make_tensor(top_k_list, torch.int32, cuda_device)
top_p = cls._make_tensor(top_p_list, torch.float32, cuda_device)
temperature = cls._make_tensor(temperature_list, torch.float32, cuda_device)
assert all(strat[0] == "top_k_top_p" for strat in strategies)
narrowed_strats = cast(list[TopKTopP], strategies)
top_k = cls._make_tensor(
[strat[1] for strat in narrowed_strats], torch.int32, cuda_device
)
top_p = cls._make_tensor(
[strat[2] for strat in narrowed_strats], torch.float32, cuda_device
)
temperature = cls._make_tensor(
[strat[3] for strat in narrowed_strats], torch.float32, cuda_device
)
return cls(top_k, top_p, temperature)

@override
@@ -427,6 +422,50 @@ def sample(
generator=generator,
), None

class TopKSampleOnly(StrategyImplSampleOnly):
def __init__(self, top_k: torch.Tensor, temperature: torch.Tensor):
self._top_k = top_k
self._temperature = temperature

@override
@classmethod
def from_strategies(
cls, strategies: list[Strategy], cuda_device: torch.device
) -> "_StrategyImpls.TopKSampleOnly":
assert all(strat[0] == "top_k" for strat in strategies)
narrowed_strats = cast(list[TopK], strategies)
top_k = cls._make_tensor(
[strat[1] for strat in narrowed_strats], torch.int32, cuda_device
)
temperature = cls._make_tensor(
[strat[2] for strat in narrowed_strats], torch.float32, cuda_device
)
return cls(top_k, temperature)

@override
def sample(
self,
logits: torch.Tensor,
*,
group_logit_indices: Optional[torch.Tensor] = None,
generator: Optional[torch.Generator] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
probs = self._prepare_probs_with_temperature(
logits, group_logit_indices, self._temperature
)
return flashinfer.sampling.top_k_sampling_from_probs(
probs,
top_k=self._top_k,
# NB: Leveraging 'indices' would require applying temperature + softmax before batching,
# because 'flashinfer.sampling.softmax' has no 'indices' argument; but that would
# compute softmax unnecessarily even in situations that allow
# flashinfer.sampling...._sampling_from_logits.
# indices=group_logit_indices,
deterministic=True,
check_nan=self._flashinfer_check_nans(probs),
generator=generator,
), None

class TopPSampleOnly(StrategyImplSampleOnly):
def __init__(self, top_p: torch.Tensor, temperature: torch.Tensor):
self._top_p = top_p
@@ -540,10 +579,9 @@ def strategy_grouping_key(strategy: Strategy, return_probs: bool) -> STRATEGY_KE
match strategy:
case ("top_p", _, _):
return _StrategyImpls.TopPSampleOnly
case ("top_k_top_p", _, _, _) | ("top_k", _, _):
# NB: There is no TopKSampleOnly, because FlashInfer only provides
# top_k_sampling_from_probs (not top_k_sampling_from_logits),
# which is likely slower than top_k_top_p_sampling_from_logits.
case ("top_k", _, _):
return _StrategyImpls.TopKSampleOnly
case ("top_k_top_p", _, _, _):
return _StrategyImpls.TopKTopPSampleOnly
case ("temperature", _):
return _StrategyImpls.TemperatureOnlySampleOnly
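For illustration, the new dispatch routes pure top-k requests through TopKSampleOnly instead of the combined top-k/top-p kernel. A minimal sketch of exercising that path in isolation, assuming _StrategyImpls is importable from this internal module, a CUDA build of FlashInfer is installed, and strategy tuples follow the ("top_k", k, temperature) layout used by from_strategies above:

import torch
from tensorrt_llm._torch.pyexecutor.sampling_utils_flashinfer import _StrategyImpls

# Hypothetical standalone use of the TopKSampleOnly implementation added in this diff.
strategies = [("top_k", 50, 0.8), ("top_k", 10, 1.0)]  # ("top_k", k, temperature)
impl = _StrategyImpls.TopKSampleOnly.from_strategies(strategies, torch.device("cuda"))

logits = torch.randn(2, 32000, device="cuda")              # (batch, vocab) raw logits
generator = torch.Generator(device="cuda").manual_seed(0)
tokens, _ = impl.sample(logits, generator=generator)       # one sampled token id per row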
31 changes: 31 additions & 0 deletions tests/unittest/_torch/sampler/test_torch_sampler.py
@@ -1601,6 +1601,37 @@ def _mock_flashinfer_from_logits(

patch_ctx.setattr(flashinfer.sampling, "sampling_from_logits", _mock_flashinfer_from_logits)

def _mock_flashinfer_top_k(
probs: torch.Tensor,
*,
top_k: torch.Tensor,
deterministic: bool,
check_nan: bool,
generator: torch.Generator,
) -> torch.Tensor:
assert deterministic
assert not check_nan, "check_nan syncs"
assert generator is sampler.get_generator(probs.device)
nonlocal mock_sampling_log
new_entries = [
TestBatchedSampling._MockSamplingLogEntry(
probs=probs[row_idx],
sampling_params=TestBatchedSampling._TorchUtilsSamplingParams(
top_k=top_k[row_idx],
top_p=None,
temperature=None,
),
)
for row_idx in range(probs.size(0))
]
mock_tokens = torch.arange(
len(mock_sampling_log), len(mock_sampling_log) + len(new_entries)
)
mock_sampling_log += new_entries
return mock_tokens

patch_ctx.setattr(flashinfer.sampling, "top_k_sampling_from_probs", _mock_flashinfer_top_k)

def _mock_flashinfer_top_p(
probs: torch.Tensor,
*,
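For reference, the mock above stands in for FlashInfer's probability-based top-k sampler, which the new TopKSampleOnly implementation calls. A rough sketch of the real call, mirroring the keyword arguments used in sampling_utils_flashinfer.py (shapes and values are illustrative only):

import torch
import flashinfer

# probs must already include temperature scaling and softmax, as done by
# _prepare_probs_with_temperature in the sampler code above.
probs = torch.softmax(torch.randn(2, 32000, device="cuda"), dim=-1)
top_k = torch.tensor([50, 10], dtype=torch.int32, device="cuda")  # per-row k
tokens = flashinfer.sampling.top_k_sampling_from_probs(
    probs,
    top_k=top_k,
    deterministic=True,
    check_nan=False,
)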