
Commit 16f5f5e

[https://nvbugs/5513423][chore] Enhance min length penalty application in TorchSampler to work with speculative decoding

- Refactored the `_apply_min_length_penalty` method to support step-wise application of minimum length penalties based on the number of steps per request.
- Updated the test for `min_tokens` to use a parameterized approach covering speculative decoding configurations.

Signed-off-by: Stefan Niebler <[email protected]>
1 parent: 261bd55 · commit: 16f5f5e
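Before the diffs, here is a minimal, self-contained sketch of the step-wise penalty described above. This is not the committed code: `SimpleRequest` and `apply_min_length_penalty` are hypothetical stand-ins for `LlmRequest` (its `py_min_length`, `max_beam_num_tokens`, and `py_end_id` fields) and for the sampler's helper, and the logits tensor is assumed to already be flattened to one row per decoding step.

```python
from dataclasses import dataclass

import torch


@dataclass
class SimpleRequest:
    """Hypothetical stand-in for LlmRequest, keeping only the fields the
    penalty needs (py_min_length[0], max_beam_num_tokens, py_end_id)."""
    min_length: int | None  # minimum number of tokens to generate, or None
    num_tokens: int         # tokens generated so far on the longest beam
    end_id: int             # the end-of-sequence token id


def apply_min_length_penalty(logits: torch.Tensor,
                             requests: list[SimpleRequest],
                             num_steps: list[int]) -> torch.Tensor:
    """Mask the end token for every decoding step that would terminate a
    request before it reaches its minimum length.

    `logits` has one row per step: request i owns the rows
    [offset, offset + num_steps[i]) where offset = sum(num_steps[:i]).
    """
    offset = 0
    for i, req in enumerate(requests):
        if req.min_length is not None:
            for step in range(num_steps[i]):
                # At this step the request would hold num_tokens + step
                # tokens; if that is still short of min_length, forbid EOS.
                if req.num_tokens + step < req.min_length:
                    logits[offset + step, req.end_id] = float('-inf')
                else:
                    break  # all later steps are already long enough
        offset += num_steps[i]
    return logits


# Tiny demo: request 0 has no draft tokens (1 step), request 1 has 3 (4 steps).
requests = [SimpleRequest(min_length=5, num_tokens=3, end_id=2),
            SimpleRequest(min_length=4, num_tokens=1, end_id=2)]
num_steps = [1, 4]
logits = torch.zeros(sum(num_steps), 8)  # 5 step rows, toy vocab of 8
apply_min_length_penalty(logits, requests, num_steps)
print(torch.isinf(logits[:, 2]).tolist())  # [True, True, True, True, False]
```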

2 files changed (+49, -16 lines)

tensorrt_llm/_torch/pyexecutor/sampler.py

Lines changed: 27 additions & 10 deletions
```diff
@@ -655,17 +655,33 @@ def _apply_embedding_bias(
         return logits
 
     @staticmethod
+    @torch.inference_mode()
     def _apply_min_length_penalty(logits: torch.Tensor,
-                                  requests: list[LlmRequest]):
+                                  requests: list[LlmRequest],
+                                  num_steps: list[int]) -> torch.Tensor:
+        """Inplace apply min_length_penalty to logits.
 
-        if not any(
-                r.py_min_length and r.max_beam_num_tokens < r.py_min_length[0]
-                for r in requests):
-            return logits
-        logits = logits.clone()
-        for index, r in enumerate(requests):
-            if r.py_min_length and r.max_beam_num_tokens < r.py_min_length[0]:
-                logits[index, [r.py_end_id]] = float('-inf')
+        Args:
+            logits: The logits to apply min length penalty to
+            requests: The requests to apply min length penalty to
+            num_steps: The number of steps per request
+
+        Returns:
+            The logits with min length penalty applied
+        """
+        if any(r.py_min_length and r.max_beam_num_tokens < r.py_min_length[0]
+               for r in requests):
+            current_offset = 0
+            for index, r in enumerate(requests):
+                if r.py_min_length:
+                    for step in range(num_steps[index]):
+                        if r.max_beam_num_tokens + step < r.py_min_length[0]:
+                            logits[current_offset + step,
+                                   r.py_end_id] = float('-inf')
+                        else:
+                            # early exit
+                            break
+                current_offset += num_steps[index]
         return logits
 
     def _process_requests(self,
@@ -696,8 +712,9 @@ def _process_requests(self,
         raw_logits = model_outputs["logits"]
 
         requests = scheduled_requests.all_requests()
-        raw_logits = self._apply_min_length_penalty(raw_logits, requests)
         num_steps = [1 + get_draft_token_length(req) for req in requests]
+        raw_logits = self._apply_min_length_penalty(raw_logits, requests,
+                                                    num_steps)
         sum_steps = sum(num_steps)
         no_draft_tokens = len(requests) == sum_steps
         fast_path = not self.enable_mixed_sampler and no_draft_tokens and log_probs_host is None
```
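In the second hunk, `_process_requests` now computes `num_steps = [1 + get_draft_token_length(req) for req in requests]` before applying the penalty. The snippet below is purely illustrative (the draft lengths are made up) and shows how those step counts map requests onto consecutive rows of the flattened logits tensor, which is what the running `current_offset` in the new helper walks over.

```python
# Illustration only: how requests map onto rows of the flattened logits.
# With speculative decoding, request i contributes
# num_steps[i] = 1 + <number of draft tokens> rows, laid out back to back.
draft_lens = [0, 3]                      # hypothetical: request 1 carries 3 draft tokens
num_steps = [1 + d for d in draft_lens]  # [1, 4]; logits then has sum(num_steps) = 5 rows

offset = 0
for i, steps in enumerate(num_steps):
    print(f"request {i}: logits rows {list(range(offset, offset + steps))}")
    offset += steps
# request 0: logits rows [0]
# request 1: logits rows [1, 2, 3, 4]
```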

tests/unittest/llmapi/test_llm_pytorch.py

Lines changed: 22 additions & 6 deletions
```diff
@@ -6,7 +6,7 @@
 from tensorrt_llm import LLM
 from tensorrt_llm.executor import GenerationExecutorWorker
 from tensorrt_llm.llmapi import KvCacheConfig
-from tensorrt_llm.llmapi.llm_args import PeftCacheConfig
+from tensorrt_llm.llmapi.llm_args import NGramDecodingConfig, PeftCacheConfig
 from tensorrt_llm.llmapi.tokenizer import TransformersTokenizer
 from tensorrt_llm.metrics import MetricNames
 from tensorrt_llm.sampling_params import SamplingParams
@@ -861,12 +861,28 @@ def test_llm_with_proxy_error():
 
 
 @pytest.mark.part0
-def test_min_tokens():
+@pytest.mark.parametrize("use_speculative", [True, False])
+def test_min_tokens(use_speculative: bool):
     """Check min_tokens is respected."""
-    llm = LLM(model=llama_model_path,
-              kv_cache_config=global_kvcache_config,
-              enable_mixed_sampler=True,
-              max_seq_len=2048)
+    llm_common_config = dict(
+        model=llama_model_path,
+        max_batch_size=2,
+        kv_cache_config=global_kvcache_config,
+        max_num_tokens=2048,
+        enable_mixed_sampler=True,
+    )
+
+    if use_speculative:
+        spec_config = NGramDecodingConfig(
+            max_draft_len=4,
+            max_matching_ngram_size=2,
+            is_keep_all=True,
+            is_use_oldest=True,
+            is_public_pool=True,
+        )
+        llm = LLM(**llm_common_config, speculative_config=spec_config)
+    else:
+        llm = LLM(**llm_common_config)
 
     output_len = 2000
     sampling_params = SamplingParams(max_tokens=output_len,
```
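The excerpt above cuts off inside the `SamplingParams` call, so the test's actual assertions are not shown. As a rough sketch of how such a check could look (assuming `SamplingParams` accepts the `min_tokens` field the test name refers to, and that each completion exposes `token_ids`), something along these lines would exercise both the plain and the speculative configuration, reusing the `llm` and `output_len` names visible in the diff:

```python
# Hypothetical sketch, not the committed test body: force generation to run
# for exactly output_len tokens and check that EOS never ends it early.
sampling_params = SamplingParams(max_tokens=output_len,
                                 min_tokens=output_len)
result = llm.generate("The future of AI is", sampling_params=sampling_params)
# With min_tokens == max_tokens, the end token stays masked until the limit,
# so the completion should contain exactly output_len generated tokens.
assert len(result.outputs[0].token_ids) == output_len
```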
