Commit ad4d9d4

[https://nvbugs/5513423][fix] Correctly respect min_tokens in PyTorch workflow using TorchSampler
- Added `py_min_length` attribute to `LlmRequest` to store the minimum length configuration.
- Implemented `_apply_min_length_penalty` method in `TorchSampler` to adjust logits based on minimum length requirements (mimics `PenaltyLayer`).
- Updated the test case for `min_tokens` to reflect the new maximum sequence length and output length constraints from the model.

Signed-off-by: Stefan Niebler <[email protected]>
1 parent c4abca3 commit ad4d9d4
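
For context, the fix follows the same idea as the `PenaltyLayer` mentioned in the commit message: while a request has generated fewer than `min_tokens` tokens, the end-of-sequence logit is forced to negative infinity so EOS can be neither sampled nor greedily selected. Below is a minimal, self-contained sketch of that idea, not the committed code; `end_id`, `min_tokens` and the single-sequence batch shape are illustrative assumptions.

import torch

def mask_end_token(logits: torch.Tensor, generated_len: int,
                   min_tokens: int, end_id: int) -> torch.Tensor:
    # Forbid EOS until at least `min_tokens` tokens have been generated.
    if generated_len < min_tokens:
        logits = logits.clone()              # do not mutate the model output in place
        logits[..., end_id] = float('-inf')  # EOS probability becomes 0 after softmax
    return logits

# Toy usage: EOS (id 2) is blocked because only 3 of the required 5 tokens exist so far.
logits = torch.zeros(1, 8)
masked = mask_end_token(logits, generated_len=3, min_tokens=5, end_id=2)
assert masked[0, 2].item() == float('-inf')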

File tree

3 files changed: +18 −3 lines changed

tensorrt_llm/_torch/pyexecutor/llm_request.py

Lines changed: 1 addition & 0 deletions
@@ -329,6 +329,7 @@ def __init__(
         self.py_prompt_len = self.prompt_len
         self.py_orig_prompt_len = self.orig_prompt_len
         self.py_max_new_tokens = self.max_new_tokens
+        self.py_min_length = self.sampling_config.min_length
         self.py_batch_idx = None
         self.py_draft_pages_allocated = 0
         self.py_rewind_len = 0
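
Note on the new attribute: `sampling_config.min_length` behaves as an optional, list-like per-request value, which is why the sampler hunk below checks truthiness before reading `py_min_length[0]`. A hedged stand-in illustrating that check; the function name and types here are assumptions for the example, not the real `SamplingConfig` API.

from typing import List, Optional

def needs_eos_mask(py_min_length: Optional[List[int]], tokens_so_far: int) -> bool:
    # None or an empty list means min_tokens was not requested for this request.
    return bool(py_min_length) and tokens_so_far < py_min_length[0]

assert needs_eos_mask([16], 3) is True    # min length 16, only 3 tokens so far -> keep masking EOS
assert needs_eos_mask(None, 3) is False   # no min_tokens -> nothing to do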

tensorrt_llm/_torch/pyexecutor/sampler.py

Lines changed: 15 additions & 0 deletions
@@ -690,6 +690,20 @@ def _apply_embedding_bias(

         return logits

+    @staticmethod
+    def _apply_min_length_penalty(logits: torch.Tensor,
+                                  requests: list[LlmRequest]):
+
+        if not any(
+                r.py_min_length and r.max_beam_num_tokens < r.py_min_length[0]
+                for r in requests):
+            return logits
+        logits = logits.clone()
+        for index, r in enumerate(requests):
+            if r.py_min_length and r.max_beam_num_tokens < r.py_min_length[0]:
+                logits[index, [r.py_end_id]] = float('-inf')
+        return logits
+
     @staticmethod
     def _longest_stop_word_len(requests: Iterable[LlmRequest]) -> int:
         max_stop_word_len = 0

@@ -905,6 +919,7 @@ def _process_requests(self,
         raw_logits = model_outputs["logits"]

         requests = scheduled_requests.all_requests()
+        raw_logits = self._apply_min_length_penalty(raw_logits, requests)
         num_steps = [1 + get_draft_token_length(req) for req in requests]
         sum_steps = sum(num_steps)
         no_draft_tokens = len(requests) == sum_steps
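
To make the new helper's behavior concrete, here is a hedged, self-contained illustration using stand-in request objects. `FakeRequest` is invented for the example; the real `LlmRequest` carries many more fields, but the helper only reads `py_min_length`, `max_beam_num_tokens` and `py_end_id`.

import torch
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class FakeRequest:
    py_min_length: Optional[List[int]]  # e.g. [32]; None/empty means no min_tokens requested
    max_beam_num_tokens: int            # tokens generated so far on the longest beam
    py_end_id: int                      # end-of-sequence token id

def apply_min_length_penalty(logits: torch.Tensor, requests) -> torch.Tensor:
    # Fast path: if no request is still below its minimum length, return logits unchanged.
    if not any(r.py_min_length and r.max_beam_num_tokens < r.py_min_length[0]
               for r in requests):
        return logits
    logits = logits.clone()
    for index, r in enumerate(requests):
        if r.py_min_length and r.max_beam_num_tokens < r.py_min_length[0]:
            logits[index, [r.py_end_id]] = float('-inf')  # block EOS for this request's row
    return logits

reqs = [FakeRequest(py_min_length=[4], max_beam_num_tokens=2, py_end_id=0),   # EOS masked
        FakeRequest(py_min_length=None, max_beam_num_tokens=2, py_end_id=0)]  # untouched
out = apply_min_length_penalty(torch.zeros(2, 8), reqs)
print(out[0, 0].item(), out[1, 0].item())  # -inf 0.0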

tests/unittest/llmapi/test_llm_pytorch.py

Lines changed: 2 additions & 3 deletions
@@ -861,15 +861,14 @@ def test_llm_with_proxy_error():


 @pytest.mark.part0
-@pytest.mark.xfail(reason="https://nvbugs/5513423")
 def test_min_tokens():
     """Check min_tokens is respected."""
     llm = LLM(model=llama_model_path,
               kv_cache_config=global_kvcache_config,
               enable_mixed_sampler=True,
-              max_seq_len=20000)
+              max_seq_len=2048)

-    output_len = 5000
+    output_len = 2000
     sampling_params = SamplingParams(max_tokens=output_len,
                                      min_tokens=output_len,
                                      temperature=1)
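
The hunk above only shows the test setup, not its assertion. A typical check for such a test might look like the sketch below; the prompt string and the exact output interface (`RequestOutput.outputs` with per-completion `token_ids`) are assumptions for illustration, not part of this commit.

outputs = llm.generate(["The future of AI is"], sampling_params)
for request_output in outputs:
    for completion in request_output.outputs:
        # min_tokens == max_tokens == output_len, so at least output_len tokens must be produced.
        assert len(completion.token_ids) >= output_len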
