tensorrt_llm/_torch/pyexecutor/model_engine.py (3 changes: 2 additions, 1 deletion)
@@ -313,7 +313,8 @@ def __init__(
         if mapping.has_pp():
             init_pp_comm(mapping)
         self.dist = dist
-        ExpertStatistic.create(self.dist.rank)
+        if dist is not None:
+            ExpertStatistic.create(self.dist.rank)
         self.pytorch_backend_config = pytorch_backend_config
         self.spec_config = spec_config
         self.is_spec_decode = spec_config is not None
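A minimal sketch of why the guard is needed, assuming the engine can be constructed with dist=None (for example in single-process unit tests); the class body and helper below are hypothetical and only illustrate the pattern from the diff:

from typing import Optional


class ExpertStatistic:
    # Hypothetical stand-in for the real statistics collector.
    rank: Optional[int] = None

    @classmethod
    def create(cls, rank: int) -> None:
        # Remember which rank this process collects statistics for.
        cls.rank = rank


def init_expert_statistics(dist) -> None:
    # dist can legitimately be None in single-process setups; reading
    # dist.rank unconditionally would raise an AttributeError.
    if dist is not None:
        ExpertStatistic.create(dist.rank)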
tests/unittest/_torch/test_pytorch_model_engine.py (11 changes: 8 additions, 3 deletions)
@@ -96,6 +96,9 @@ def create_model_engine_and_kvcache(config: PyTorchConfig = None):
 
     config = config if config else PyTorchConfig(
         use_cuda_graph=True, cuda_graph_padding_enabled=True)
+    config.cuda_graph_batch_sizes = [
+        1, 2, 4, 8, 16, 32, 64, 128
+    ] if config.cuda_graph_batch_sizes is None else config.cuda_graph_batch_sizes
     test_batches = (5, 13)
     for batch_size in test_batches:
         assert batch_size not in config.cuda_graph_batch_sizes
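The default list above matters because the test then feeds batch sizes (5 and 13) that are deliberately absent from it, forcing the padding path to round each batch up to the next captured CUDA graph size. A standalone sketch of that rounding, assuming padding targets the smallest configured size that fits (the helper name is illustrative, not the engine's API):

from bisect import bisect_left
from typing import List, Optional


def next_cuda_graph_batch_size(batch_size: int,
                               graph_sizes: List[int]) -> Optional[int]:
    # Return the smallest captured CUDA graph batch size that can hold the
    # batch, or None if the batch is larger than every captured size.
    sizes = sorted(graph_sizes)
    idx = bisect_left(sizes, batch_size)
    return sizes[idx] if idx < len(sizes) else None


# With the defaults above, a batch of 5 pads to 8 and a batch of 13 pads to 16.
assert next_cuda_graph_batch_size(5, [1, 2, 4, 8, 16, 32, 64, 128]) == 8
assert next_cuda_graph_batch_size(13, [1, 2, 4, 8, 16, 32, 64, 128]) == 16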
@@ -153,6 +156,7 @@ def test_pad_generation_requests(self) -> None:
             batch.context_requests = []
             batch.generation_requests = requests
             pages_before = kv_cache_manager.get_num_free_blocks()
+            new_dummy_block = 1 if model_engine.cuda_graph_dummy_request is None else 0
             with model_engine._maybe_pad_batch(
                     batch, kv_cache_manager) as padded_batch:
                 if batch_size < 8 and max_seq_len < 25:
@@ -165,8 +169,9 @@
                     # The seqlen check makes sure we don't exceed the KV cache memory
                     # budget.
                     self.assertIs(batch, padded_batch)
-            self.assertEqual(kv_cache_manager.get_num_free_blocks(),
-                             pages_before)
+            self.assertEqual(
+                kv_cache_manager.get_num_free_blocks() + new_dummy_block,
+                pages_before)
 
         kv_cache_manager.shutdown()
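The adjusted assertion reflects what the diff implies: the first time a batch is padded, the engine may lazily create a persistent CUDA graph dummy request that keeps one KV cache block afterwards, so the test tolerates exactly one "missing" free block in that case. A small arithmetic sketch of the same accounting (numbers are made up):

# Illustrative numbers only; the real values come from the KV cache manager.
pages_before = 100                   # free blocks before padding
dummy_request_already_exists = False

# One block stays allocated the first time the dummy request is created.
new_dummy_block = 0 if dummy_request_already_exists else 1
pages_after = pages_before - new_dummy_block   # 99 in this example

# The updated test asserts exactly this relationship.
assert pages_after + new_dummy_block == pages_before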

@@ -205,7 +210,7 @@ def test_position_id_preparation(self):
 
         model_engine.forward(batch, resource_manager)
         expected_gen_pos_id = torch.tensor([prompt_len],
-                                           dtype=torch.int64,
+                                           dtype=torch.int32,
                                            device='cuda').unsqueeze(0)
         torch.testing.assert_close(model_engine.model.recorded_position_ids,
                                    expected_gen_pos_id,
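The expected tensor now uses torch.int32 to match the dtype the engine now produces for generation position ids; torch.testing.assert_close compares dtypes by default, so an int64 expectation would fail against an int32 result. A CPU-only sketch of that behaviour (prompt_len is an assumed value):

import torch

prompt_len = 7  # assumed value purely for illustration
expected = torch.tensor([prompt_len], dtype=torch.int32).unsqueeze(0)
actual = torch.tensor([[prompt_len]], dtype=torch.int32)

# assert_close checks dtypes by default (check_dtype=True), so the expected
# tensor must use the same int32 dtype as the position ids under test.
torch.testing.assert_close(actual, expected)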