diff --git a/tensorrt_llm/_torch/pyexecutor/_util.py b/tensorrt_llm/_torch/pyexecutor/_util.py index 4b04c99644b..3f27c34e864 100644 --- a/tensorrt_llm/_torch/pyexecutor/_util.py +++ b/tensorrt_llm/_torch/pyexecutor/_util.py @@ -828,7 +828,7 @@ def create_torch_sampler_args(mapping: Mapping, *, max_seq_len: int, max_batch_size: int, speculative_config: SpeculativeConfig, max_beam_width: int, - disable_flash_infer_sampling: bool): + disable_flashinfer_sampling: bool): max_num_sequences = max_batch_size * mapping.pp_size max_draft_len = (0 if speculative_config is None else speculative_config.max_draft_len) @@ -841,7 +841,7 @@ def create_torch_sampler_args(mapping: Mapping, *, max_seq_len: int, max_total_draft_tokens=max_total_draft_tokens, max_num_sequences=max_num_sequences, max_beam_width=max_beam_width, - disable_flash_infer_sampling=disable_flash_infer_sampling, + disable_flashinfer_sampling=disable_flashinfer_sampling, ) @@ -857,7 +857,7 @@ def instantiate_sampler( speculative_config: SpeculativeConfig, decoding_config: trtllm.DecodingConfig, kv_cache_config: KvCacheConfig, - disable_flash_infer_sampling: bool, + disable_flashinfer_sampling: bool, ): sampler_args = create_torch_sampler_args( mapping, @@ -865,7 +865,7 @@ def instantiate_sampler( max_batch_size=max_batch_size, speculative_config=speculative_config, max_beam_width=max_beam_width, - disable_flash_infer_sampling=disable_flash_infer_sampling, + disable_flashinfer_sampling=disable_flashinfer_sampling, ) decoding_mode = get_decoding_mode(decoding_config=decoding_config, max_beam_width=max_beam_width) diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py index 923599562ce..3605369a7f6 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py @@ -533,7 +533,7 @@ def drafting_loop_wrapper(model): speculative_config=spec_config, decoding_config=decoding_config, 
kv_cache_config=kv_cache_config, - disable_flash_infer_sampling=llm_args._disable_flash_infer_sampling, + disable_flashinfer_sampling=llm_args.disable_flashinfer_sampling, ) logger.info(f"Using Sampler: {type(sampler).__name__}") diff --git a/tensorrt_llm/_torch/pyexecutor/sampler.py b/tensorrt_llm/_torch/pyexecutor/sampler.py index 2dea9a25537..5feacbe6ca4 100644 --- a/tensorrt_llm/_torch/pyexecutor/sampler.py +++ b/tensorrt_llm/_torch/pyexecutor/sampler.py @@ -616,7 +616,7 @@ class Args: max_num_sequences: int max_beam_width: int max_total_draft_tokens: int - disable_flash_infer_sampling: bool = False + disable_flashinfer_sampling: bool = False def __init__(self, args: Args): self.max_seq_len = args.max_seq_len @@ -652,7 +652,7 @@ def __init__(self, args: Args): } self._grouped_sampler_cls: Type[GroupedStrategySampler] - if IS_FLASHINFER_AVAILABLE and not args.disable_flash_infer_sampling: + if IS_FLASHINFER_AVAILABLE and not args.disable_flashinfer_sampling: from .sampling_utils_flashinfer import FlashInferGroupedStrategySampler self._grouped_sampler_cls = FlashInferGroupedStrategySampler diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py index 58b05edad2b..337127fa4ae 100644 --- a/tensorrt_llm/llmapi/llm_args.py +++ b/tensorrt_llm/llmapi/llm_args.py @@ -2707,8 +2707,12 @@ class TorchLlmArgs(BaseLlmArgs): # PrivateVars _quant_config: Optional[QuantConfig] = PrivateAttr(default=None) - _disable_flash_infer_sampling: bool = PrivateAttr(default=True) - """Unless this is set to False, FlashInfer.sampling is not used, even if available.""" + disable_flashinfer_sampling: bool = Field( + default=True, + description= + "Disable the use of FlashInfer.sampling. 
This option is likely to be removed in the future.", +        status="prototype", +    ) @property def quant_config(self) -> QuantConfig: diff --git a/tests/unittest/_torch/sampler/test_torch_sampler.py b/tests/unittest/_torch/sampler/test_torch_sampler.py index cc2e52904ce..a782e8eb400 100644 --- a/tests/unittest/_torch/sampler/test_torch_sampler.py +++ b/tests/unittest/_torch/sampler/test_torch_sampler.py @@ -1077,7 +1077,7 @@ def _build_sampler( max_beam_width=1, # currently the only supported value max_num_sequences=num_seq_slots, max_total_draft_tokens=max_draft_len, - disable_flash_infer_sampling=(not use_flashinfer), + disable_flashinfer_sampling=(not use_flashinfer), ) ) diff --git a/tests/unittest/api_stability/references/llm.yaml b/tests/unittest/api_stability/references/llm.yaml index c08ec37c119..18308ab290e 100644 --- a/tests/unittest/api_stability/references/llm.yaml +++ b/tests/unittest/api_stability/references/llm.yaml @@ -107,6 +107,10 @@ methods: annotation: bool default: False status: beta + disable_flashinfer_sampling: + annotation: bool + default: True + status: prototype moe_config: annotation: tensorrt_llm.llmapi.llm_args.MoeConfig status: beta