diff --git a/tensorrt_llm/_torch/pyexecutor/guided_decoder.py b/tensorrt_llm/_torch/pyexecutor/guided_decoder.py
index 664f1aaca88..cc262699d82 100644
--- a/tensorrt_llm/_torch/pyexecutor/guided_decoder.py
+++ b/tensorrt_llm/_torch/pyexecutor/guided_decoder.py
@@ -83,7 +83,11 @@ def _require_matcher_advance(self, llm_req: LlmRequest) -> bool:
         if llm_req.guided_decoding_params is None:
             return False
         if llm_req.py_is_draft:
-            return True
+            if llm_req.is_context_init_state and llm_req.is_last_context_chunk:
+                return True
+            if llm_req.is_generation_in_progress_state:
+                return True
+            return False
         # The request is in a generation forward step.
         return llm_req.is_generation_in_progress_state

@@ -189,7 +193,8 @@ def execute(self,
                     batched_bitmask.append(self.bitmask[slot, i])
             offset += len(llm_req.py_draft_tokens) + 1

-        assert offset == logits.size(0)
+        # Dummy logits may exist for CUDA graph dummy requests.
+        assert offset <= logits.size(0)

         if len(batched_logits) > 0:
             torch.ops.trtllm.logits_bitmask(batched_logits, batched_bitmask)
diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch.py b/tests/integration/defs/accuracy/test_llm_api_pytorch.py
index 548b020887f..1cae8fcb985 100644
--- a/tests/integration/defs/accuracy/test_llm_api_pytorch.py
+++ b/tests/integration/defs/accuracy/test_llm_api_pytorch.py
@@ -325,6 +325,7 @@ def test_guided_decoding_4gpus(self, backend: str, mocker):
     def test_guided_decoding_with_eagle3(self, backend: str, mocker):
         mocker.patch.dict(os.environ, {"TRTLLM_XGUIDANCE_LENIENT": "1"})
         kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.8)
+        cuda_graph_config = CudaGraphConfig(enable_padding=True)
         spec_config = EagleDecodingConfig(
             max_draft_len=3,
             speculative_model_dir=
@@ -333,6 +334,8 @@ def test_guided_decoding_with_eagle3(self, backend: str, mocker):
         llm = LLM(self.MODEL_PATH,
                   guided_decoding_backend=backend,
                   kv_cache_config=kv_cache_config,
+                  cuda_graph_config=cuda_graph_config,
+                  enable_chunked_prefill=True,
                   speculative_config=spec_config,
                   disable_overlap_scheduler=True)
         with llm:
@@ -344,11 +347,14 @@ def test_guided_decoding_with_ngram(self, backend: str, mocker):
         mocker.patch.dict(os.environ, {"TRTLLM_XGUIDANCE_LENIENT": "1"})
         kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.8)
+        cuda_graph_config = CudaGraphConfig(enable_padding=True)
         spec_config = NGramDecodingConfig(max_draft_len=3,
                                           max_matching_ngram_size=3)
         llm = LLM(self.MODEL_PATH,
                   guided_decoding_backend=backend,
                   kv_cache_config=kv_cache_config,
+                  cuda_graph_config=cuda_graph_config,
+                  enable_chunked_prefill=True,
                   speculative_config=spec_config,
                   disable_overlap_scheduler=True)
         with llm: