Commit 8aead22

stniedcampora and Daniel Cámpora authored
[https://nvbugs/5513423][fix] Correctly respect min_tokens in PyTorch Workflow (#7808)
Signed-off-by: Stefan Niebler <[email protected]>
Co-authored-by: Daniel Cámpora <[email protected]>
1 parent 9dc7316 commit 8aead22

File tree: 3 files changed, +56 -8 lines changed

tensorrt_llm/_torch/pyexecutor/llm_request.py

Lines changed: 1 addition & 0 deletions
@@ -327,6 +327,7 @@ def __init__(
         self.py_prompt_len = self.prompt_len
         self.py_orig_prompt_len = self.orig_prompt_len
         self.py_max_new_tokens = self.max_new_tokens
+        self.py_min_length = self.sampling_config.min_length
         self.py_batch_idx = None
         self.py_draft_pages_allocated = 0
         self.py_rewind_len = 0
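
The cached py_min_length mirrors the min_tokens value that callers set through the LLM API. A minimal orientation sketch of that relationship; the note about the value being stored as a sequence is an inference from the py_min_length[0] access in the sampler change below, not something stated in this diff:

from tensorrt_llm.sampling_params import SamplingParams

# min_tokens is the user-facing knob; the PyTorch executor now caches the
# resulting sampling_config.min_length on each LlmRequest as py_min_length
# (read as py_min_length[0] in the sampler, i.e. stored as a sequence).
params = SamplingParams(max_tokens=64, min_tokens=16, temperature=1)
# llm.generate(prompts, params) should then emit at least 16 new tokens per
# sequence before the end-of-sequence token may be sampled.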

tensorrt_llm/_torch/pyexecutor/sampler.py

Lines changed: 32 additions & 0 deletions
@@ -657,6 +657,36 @@ def _apply_embedding_bias(

         return logits

+    @staticmethod
+    @torch.inference_mode()
+    def _apply_min_length_penalty(logits: torch.Tensor,
+                                  requests: list[LlmRequest],
+                                  num_steps: list[int]) -> torch.Tensor:
+        """Inplace apply min_length_penalty to logits.
+
+        Args:
+            logits: The logits to apply min length penalty to
+            requests: The requests to apply min length penalty to
+            num_steps: The number of steps per request
+
+        Returns:
+            The logits with min length penalty applied
+        """
+        if any(r.py_min_length and r.max_beam_num_tokens < r.py_min_length[0]
+               for r in requests):
+            current_offset = 0
+            for index, r in enumerate(requests):
+                if r.py_min_length:
+                    for step in range(num_steps[index]):
+                        if r.max_beam_num_tokens + step < r.py_min_length[0]:
+                            logits[current_offset + step,
+                                   r.py_end_id] = float('-inf')
+                        else:
+                            #early exit
+                            break
+                current_offset += num_steps[index]
+        return logits
+
     def _process_requests(self,
                           scheduled_requests: ScheduledRequests,
                           model_outputs: dict[str, torch.Tensor],
@@ -686,6 +716,8 @@ def _process_requests(self,

         requests = scheduled_requests.all_requests()
         num_steps = [1 + get_draft_token_length(req) for req in requests]
+        raw_logits = self._apply_min_length_penalty(raw_logits, requests,
+                                                    num_steps)
         sum_steps = sum(num_steps)
         no_draft_tokens = len(requests) == sum_steps
         fast_path = not self.enable_mixed_sampler and no_draft_tokens and log_probs_host is None
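
The helper walks the flattened logits tensor, one row per decoding step (the target step plus any draft steps), and forces the end-of-sequence logit to -inf while a request is still below its minimum length. A standalone toy reproduction of that masking idea, using a hypothetical Request stand-in instead of LlmRequest:

from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class Request:
    """Hypothetical stand-in for LlmRequest, reduced to the fields the mask needs."""
    generated_tokens: int      # tokens produced so far (max_beam_num_tokens)
    min_length: Optional[int]  # minimum tokens to produce, None if unset
    end_id: int                # end-of-sequence token id


def mask_end_token(logits: torch.Tensor, requests: list,
                   num_steps: list) -> torch.Tensor:
    """Set the end-token logit to -inf for every step that would end a sequence early."""
    offset = 0
    for steps, req in zip(num_steps, requests):
        if req.min_length is not None:
            for step in range(steps):
                if req.generated_tokens + step < req.min_length:
                    logits[offset + step, req.end_id] = float("-inf")
                else:
                    break  # later steps already satisfy the minimum
        offset += steps
    return logits


# Request 0 contributes one logits row, request 1 two rows (e.g. one draft token).
logits = torch.zeros(3, 8)
reqs = [Request(generated_tokens=5, min_length=2, end_id=7),
        Request(generated_tokens=2, min_length=3, end_id=7)]
mask_end_token(logits, reqs, num_steps=[1, 2])
print(logits[:, 7])  # tensor([0., -inf, 0.]) -- only the under-length step is masked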

tests/unittest/llmapi/test_llm_pytorch.py

Lines changed: 23 additions & 8 deletions
@@ -6,7 +6,7 @@
 from tensorrt_llm import LLM
 from tensorrt_llm.executor import GenerationExecutorWorker
 from tensorrt_llm.llmapi import KvCacheConfig
-from tensorrt_llm.llmapi.llm_args import PeftCacheConfig
+from tensorrt_llm.llmapi.llm_args import NGramDecodingConfig, PeftCacheConfig
 from tensorrt_llm.llmapi.tokenizer import TransformersTokenizer
 from tensorrt_llm.metrics import MetricNames
 from tensorrt_llm.sampling_params import SamplingParams
@@ -861,15 +861,30 @@ def test_llm_with_proxy_error():


 @pytest.mark.part0
-@pytest.mark.xfail(reason="https://nvbugs/5513423")
-def test_min_tokens():
+@pytest.mark.parametrize("use_speculative", [True, False])
+def test_min_tokens(use_speculative: bool):
     """Check min_tokens is respected."""
-    llm = LLM(model=llama_model_path,
-              kv_cache_config=global_kvcache_config,
-              enable_mixed_sampler=True,
-              max_seq_len=20000)
+    llm_common_config = dict(
+        model=llama_model_path,
+        max_batch_size=2,
+        kv_cache_config=global_kvcache_config,
+        max_num_tokens=2048,
+        enable_mixed_sampler=True,
+    )
+
+    if use_speculative:
+        spec_config = NGramDecodingConfig(
+            max_draft_len=4,
+            max_matching_ngram_size=2,
+            is_keep_all=True,
+            is_use_oldest=True,
+            is_public_pool=True,
+        )
+        llm = LLM(**llm_common_config, speculative_config=spec_config)
+    else:
+        llm = LLM(**llm_common_config)

-    output_len = 5000
+    output_len = 2000
     sampling_params = SamplingParams(max_tokens=output_len,
                                      min_tokens=output_len,
                                      temperature=1)
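
The remainder of the test body is unchanged and therefore not shown in this diff. A hedged sketch of how such a test would typically close out; the prompt and the exact assertion below are assumptions for illustration, not the committed code:

# Hypothetical continuation of test_min_tokens (not part of this diff).
prompts = ["The future of AI is"]
for output in llm.generate(prompts, sampling_params):
    for completion in output.outputs:
        # With min_tokens == max_tokens, every completion should reach output_len.
        assert len(completion.token_ids) == output_len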
