20 changes: 6 additions & 14 deletions tensorrt_llm/_torch/pyexecutor/executor_request_queue.py
@@ -310,7 +310,7 @@ def _fetch_and_process_requests(
new_requests)

# Validate and filter requests
new_requests = self._validate_and_filter_requests(new_requests)
new_requests = self._handle_special_queue_items(new_requests)

# Attach Python objects to requests
if py_request_objects and (self.dist.tp_size > 1
@@ -482,11 +482,11 @@ def _handle_request_broadcasting(self,

return new_requests, py_request_objects

def _validate_and_filter_requests(
def _handle_special_queue_items(
self,
new_requests: List[RequestQueueItem]) -> List[RequestQueueItem]:
"""Validate and filter requests, handling shutdown signals."""
valid_new_requests = []
"""Handle special signals."""
accepted_new_requests = []
for idx, req_item in enumerate(new_requests):
if req_item.is_shutdown_request:
self.is_shutdown = True
@@ -499,17 +499,9 @@ def _validate_and_filter_requests(
self.request_accumulated.extend(new_requests[idx + 1:])
break
else:
valid_new_requests.append(req_item)
accepted_new_requests.append(req_item)

# Check beam width validation
for req_item in valid_new_requests:
if req_item.request and hasattr(req_item.request,
'sampling_config'):
assert req_item.request.sampling_config.beam_width == self.max_beam_width, \
f"Request beam width {req_item.request.sampling_config.beam_width} " \
f"is not equal to max_beam_width {self.max_beam_width}. This is not supported!"

return valid_new_requests
return accepted_new_requests

def _balance_requests_across_ranks(
self, new_requests: List[RequestQueueItem],
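For reference, a hypothetical, self-contained sketch (not part of this diff; names simplified, and the real method also handles cancellation items hidden in the collapsed region) of the control flow the renamed _handle_special_queue_items keeps after the beam-width check moved out: a shutdown signal is consumed and the remainder of the batch is accumulated for a later drain, while ordinary items pass through unchanged.

    from dataclasses import dataclass, field
    from typing import Any, List


    @dataclass
    class QueueItem:
        request: Any = None
        is_shutdown_request: bool = False


    @dataclass
    class MiniQueue:
        is_shutdown: bool = False
        request_accumulated: List[QueueItem] = field(default_factory=list)

        def handle_special_queue_items(
                self, new_requests: List[QueueItem]) -> List[QueueItem]:
            accepted = []
            for idx, item in enumerate(new_requests):
                if item.is_shutdown_request:
                    # Consume the shutdown signal and stash the rest for later.
                    self.is_shutdown = True
                    self.request_accumulated.extend(new_requests[idx + 1:])
                    break
                accepted.append(item)
            return accepted


    # Usage: the shutdown item is consumed, the trailing item is accumulated.
    q = MiniQueue()
    kept = q.handle_special_queue_items(
        [QueueItem("a"), QueueItem(is_shutdown_request=True), QueueItem("b")])
    assert kept == [QueueItem("a")]
    assert q.is_shutdown and len(q.request_accumulated) == 1
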
10 changes: 10 additions & 0 deletions tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -1600,6 +1600,16 @@ def _forward_step_inter_pp(self, scheduled_batch) -> SampleState:
)

def _validate_request(self, request: LlmRequest):
# Validate beam width
sampling_config = request.sampling_config
if sampling_config is not None:
if sampling_config.beam_width != self.max_beam_width:
raise ValueError(
f"Request beam width {sampling_config.beam_width} "
f"is not equal to max_beam_width {self.max_beam_width}. This is not supported!"
)

# Check token ID ranges
if isinstance(self.model_engine.model, DecoderModelForCausalLM):
            # Only skip token-range checks for Llama4 when the request has multimodal data
from ..models.modeling_llama import Llama4ForConditionalGeneration
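To illustrate the addition above, a minimal, hypothetical reproduction of the beam-width guard (standalone names; the real check lives in PyExecutor._validate_request and runs before the token-ID range checks): a request whose beam width differs from the engine's max_beam_width now fails fast with a ValueError instead of reaching the sampler.

    from dataclasses import dataclass
    from typing import Optional


    @dataclass
    class SamplingConfig:
        beam_width: int


    def validate_beam_width(sampling_config: Optional[SamplingConfig],
                            max_beam_width: int) -> None:
        # Mirror of the new guard: reject a mismatched beam width up front.
        if sampling_config is not None and sampling_config.beam_width != max_beam_width:
            raise ValueError(
                f"Request beam width {sampling_config.beam_width} "
                f"is not equal to max_beam_width {max_beam_width}. This is not supported!")


    # A request asking for 2 beams against an engine built with max_beam_width=4:
    try:
        validate_beam_width(SamplingConfig(beam_width=2), max_beam_width=4)
    except ValueError as err:
        print(err)
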
6 changes: 3 additions & 3 deletions tests/unittest/_torch/executor/test_executor_request_queue.py
@@ -475,8 +475,8 @@ def test_get_from_waiting_queue_edge_cases(executor_queue, queue_size,
assert len(executor_queue.waiting_queue) == expected_remaining


def test_validate_and_filter_requests(executor_queue):
"""Test request validation and filtering."""
def test_handle_special_queue_items(executor_queue):
"""Test special queue item handling."""
# Create a mock request without sampling_config to avoid beam validation
mock_request = Mock()
delattr(mock_request, 'sampling_config') if hasattr(
@@ -488,7 +488,7 @@ def test_validate_and_filter_requests(executor_queue):

requests = [normal_req, cancel_req, shutdown_req]

valid_requests = executor_queue._validate_and_filter_requests(requests)
valid_requests = executor_queue._handle_special_queue_items(requests)

assert len(valid_requests) == 1
assert valid_requests[0] == normal_req
139 changes: 129 additions & 10 deletions tests/unittest/_torch/sampler/test_beam_search.py
@@ -18,6 +18,8 @@
import pytest
import torch
from transformers.configuration_utils import PretrainedConfig
from utils.llm_data import llm_models_root
from utils.util import force_ampere

from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm._torch.attention_backend.interface import AttentionMetadata
@@ -31,6 +33,7 @@
from tensorrt_llm._torch.models.modeling_utils import (
ModelConfig, register_auto_model, register_checkpoint_weight_loader,
register_config_loader)
from tensorrt_llm.executor import RequestError
from tensorrt_llm.executor.result import CompletionOutput, GenerationResult
from tensorrt_llm.llmapi import CudaGraphConfig, KvCacheConfig

@@ -263,11 +266,21 @@ def fixed_params():


@pytest.fixture(scope="module")
def llm(fixed_params, input_prompts):
def model_kwargs(fixed_params) -> dict[str, Any]:
assert fixed_params[
"max_beam_width"] == 2, "This test only works for a beam width of 2"
return LLM(
return dict(
model=_pl.Path("dummy_path"),
checkpoint_loader=HfCheckpointLoader(
weight_loader=DummyWeightLoader(),
config_loader=DummyConfigLoader(),
),
)


def _build_llm(fixed_params, input_prompts, model_kwargs):
return LLM(
**model_kwargs,
kv_cache_config=KvCacheConfig(max_tokens=10000),
max_batch_size=fixed_params["max_beam_width"] * len(
input_prompts
@@ -276,16 +289,18 @@ def llm(fixed_params, input_prompts):
max_beam_width=fixed_params["max_beam_width"],
disable_overlap_scheduler=True,
cuda_graph_config=None,
checkpoint_loader=HfCheckpointLoader(weight_loader=DummyWeightLoader(),
config_loader=DummyConfigLoader()))
)


@pytest.fixture(scope="module")
def llm_cuda_graph(fixed_params, input_prompts):
assert fixed_params[
"max_beam_width"] == 2, "This test only works for a beam width of 2"
def llm(fixed_params, input_prompts, model_kwargs):
return _build_llm(fixed_params, input_prompts, model_kwargs)


@pytest.fixture(scope="module")
def llm_cuda_graph(fixed_params, input_prompts, model_kwargs):
return LLM(
model=_pl.Path("dummy_path"),
**model_kwargs,
kv_cache_config=KvCacheConfig(max_tokens=10000),
max_batch_size=fixed_params["max_beam_width"] * len(
input_prompts
@@ -295,8 +310,7 @@ def llm_cuda_graph(fixed_params, input_prompts):
disable_overlap_scheduler=False,
cuda_graph_config=CudaGraphConfig(batch_sizes=[1, 2, 4, 8],
enable_padding=True),
checkpoint_loader=HfCheckpointLoader(weight_loader=DummyWeightLoader(),
config_loader=DummyConfigLoader()))
)


def check_generation_logits(beam: CompletionOutput,
@@ -473,5 +487,110 @@ def test_beam_search_output_shapes_cuda_graph_and_overlap(
sampling_params)


@force_ampere # Save H100 resource
class TestParameterValidation:
"""Ensure that unsupported request parameters do not crash/hang the engine."""

@pytest.fixture(scope="module")
@staticmethod
def fixed_params():
return {"max_tokens": 8, "max_beam_width": 4}

@pytest.fixture(scope="module")
@staticmethod
def model_kwargs() -> dict[str, Any]:
root = llm_models_root()
assert root is not None
return dict(model=root / "llama-models-v2" /
"TinyLlama-1.1B-Chat-v1.0", )

# NB: Class-level fixture overrides do not work without this
@pytest.fixture(scope="module")
@staticmethod
def llm(fixed_params, input_prompts, model_kwargs):
return _build_llm(fixed_params, input_prompts, model_kwargs)

def _check_engine_responds(self, llm: LLM, input_prompts: list[str],
fixed_params: dict):
_ = llm.generate(input_prompts,
sampling_params=SamplingParams(
max_tokens=fixed_params["max_tokens"],
n=1,
best_of=fixed_params["max_beam_width"],
use_beam_search=True,
end_id=-1,
))

@pytest.mark.timeout(120)
@pytest.mark.threadleak(enabled=False)
def test_use_beam_search_false(
self,
llm: LLM,
input_prompts: list[str],
fixed_params: dict,
):
assert fixed_params["max_beam_width"] > 2
with pytest.raises(
ValueError,
match=
".*Greedy decoding in the LLM API does not allow multiple returns.*"
):
_ = llm.generate(input_prompts,
sampling_params=SamplingParams(
max_tokens=fixed_params["max_tokens"],
n=1,
best_of=fixed_params["max_beam_width"],
use_beam_search=False,
end_id=-1,
))
self._check_engine_responds(llm, input_prompts, fixed_params)

@pytest.mark.timeout(120)
@pytest.mark.threadleak(enabled=False)
def test_use_beam_search_ommitted(
self,
llm: LLM,
input_prompts: list[str],
fixed_params: dict,
):
assert fixed_params["max_beam_width"] > 2
with pytest.raises(
ValueError,
match=
".*Greedy decoding in the LLM API does not allow multiple returns.*"
):
_ = llm.generate(input_prompts,
sampling_params=SamplingParams(
max_tokens=fixed_params["max_tokens"],
n=1,
best_of=fixed_params["max_beam_width"],
end_id=-1,
))
self._check_engine_responds(llm, input_prompts, fixed_params)

@pytest.mark.timeout(120)
@pytest.mark.threadleak(enabled=False)
def test_smaller_beam_width(
self,
llm: LLM,
input_prompts: list[str],
fixed_params: dict,
):
assert fixed_params["max_beam_width"] > 2
with pytest.raises(
RequestError,
match=".*Request beam width 2 is not equal to max_beam_width 4*"
):
_ = llm.generate(input_prompts,
sampling_params=SamplingParams(
max_tokens=fixed_params["max_tokens"],
n=1,
best_of=2,
use_beam_search=True,
end_id=-1,
))
self._check_engine_responds(llm, input_prompts, fixed_params)


if __name__ == "__main__":
pytest.main([__file__])