diff --git a/ATTRIBUTIONS-Python.md b/ATTRIBUTIONS-Python.md index 066af45b0bc..93abb598cc8 100644 --- a/ATTRIBUTIONS-Python.md +++ b/ATTRIBUTIONS-Python.md @@ -5261,7 +5261,7 @@ For more information, please refer to - `Tracker`: https://github.com/tox-dev/py-filelock/issues -## flashinfer-python (0.6.3) +## flashinfer-python (0.6.4) ### Licenses License: `Apache-2.0` diff --git a/requirements.txt b/requirements.txt index d2aa81843d8..14e8dc10f71 100644 --- a/requirements.txt +++ b/requirements.txt @@ -54,7 +54,7 @@ ordered-set peft patchelf einops -flashinfer-python==0.6.3 +flashinfer-python==0.6.4 opencv-python-headless xgrammar==0.1.25 llguidance==0.7.29 diff --git a/security_scanning/pyproject.toml b/security_scanning/pyproject.toml index f2c0dc15ab2..c2a79f88100 100644 --- a/security_scanning/pyproject.toml +++ b/security_scanning/pyproject.toml @@ -56,7 +56,7 @@ dependencies = [ "peft (>=0.18.1,<0.19.0)", "patchelf (>=0.17.2.4,<0.18.0.0)", "einops (>=0.8.2,<0.9.0)", - "flashinfer-python (==0.6.3)", + "flashinfer-python (==0.6.4)", "xgrammar (==0.1.25)", "llguidance (==0.7.29)", "jsonschema (>=4.26.0,<5.0.0)", diff --git a/tensorrt_llm/_torch/speculative/interface.py b/tensorrt_llm/_torch/speculative/interface.py index 523a78129f0..e649794c70f 100644 --- a/tensorrt_llm/_torch/speculative/interface.py +++ b/tensorrt_llm/_torch/speculative/interface.py @@ -388,9 +388,9 @@ def __init__(self, use_separate_draft_kv_cache: bool = False): super().__init__() self.guided_decoder: Optional["CapturableGuidedDecoder"] = None self.force_num_accepted_tokens = get_force_num_accepted_tokens() - self.use_flashinfer = IS_FLASHINFER_AVAILABLE and flashinfer.__version__ >= "0.6.0" - self.seed = 0 - self.offset = 0 + self.use_flashinfer = IS_FLASHINFER_AVAILABLE and flashinfer.__version__ >= "0.6.4" + self.seed: Optional[torch.Tensor] = None + self.offset: Optional[torch.Tensor] = None self.use_separate_draft_kv_cache = use_separate_draft_kv_cache @property @@ -688,7 +688,16 @@ def _sample_tokens_for_batch( top_ps = spec_metadata.top_ps[:num_tokens] if self.use_flashinfer: + # Lazily initialize seed/offset tensors on correct device + if self.seed is None: + self.seed = torch.tensor([0], + dtype=torch.int64, + device=logits.device) + self.offset = torch.tensor([0], + dtype=torch.int64, + device=logits.device) self.seed += 1 + self.seed %= (2**31) sampled_tokens = sampling_batch_spec_dec_one_model( logits,