Merged
Changes from 6 commits (25 commits total)
a4e6235  Make the should_use_spec_decode logic a bit smarter. (zheyuf, Aug 21, 2025)
06875c2  Merge branch 'NVIDIA:main' into should_use_spec (zheyuf, Aug 21, 2025)
c970d04  Add test cases for should_use_spec_decode. (zheyuf, Aug 23, 2025)
6e75d48  Follow CodeRabbit's comments. (zheyuf, Aug 25, 2025)
08b2040  Merge branch 'NVIDIA:main' into should_use_spec (zheyuf, Aug 25, 2025)
556438b  Merge from origin. (zheyuf, Aug 25, 2025)
f6dc52b  Minor change. (zheyuf, Aug 25, 2025)
9a28cda  Use assert <expr> directly. (zheyuf, Aug 25, 2025)
689569b  Make the should_use_spec_decode logic a bit smarter. (zheyuf, Aug 21, 2025)
9034d8d  Add test cases for should_use_spec_decode. (zheyuf, Aug 23, 2025)
8cde4f2  Follow CodeRabbit's comments. (zheyuf, Aug 25, 2025)
a9cf635  Minor change. (zheyuf, Aug 25, 2025)
d0adbbd  Use assert <expr> directly. (zheyuf, Aug 25, 2025)
22997ec  Merge branch 'main' into should_use_spec (mikeiovine, Aug 27, 2025)
5ba2eb4  Change max(0, max_draft_len) to max_draft_len (zheyuf, Aug 29, 2025)
a01bc8d  Merge remote-tracking branch 'origin/should_use_spec' into should_use… (zheyuf, Aug 29, 2025)
c134cf1  Merge branch 'main' into should_use_spec (zheyuf, Aug 29, 2025)
690e9cf  Merge branch 'main' into should_use_spec (zheyuf, Aug 29, 2025)
54c30fc  Merge branch 'main' into should_use_spec (zheyuf, Aug 30, 2025)
a896b91  Merge branch 'main' into should_use_spec (mikeiovine, Sep 2, 2025)
6381f9c  Merge branch 'main' into should_use_spec (zheyuf, Sep 3, 2025)
00d498a  Merge branch 'main' into should_use_spec (zheyuf, Sep 4, 2025)
2e9aee5  Merge branch 'main' into should_use_spec (zheyuf, Sep 6, 2025)
8c47cc5  Merge branch 'main' into should_use_spec (zheyuf, Sep 8, 2025)
5764651  Merge branch 'main' into should_use_spec (zheyuf, Sep 9, 2025)
4 changes: 3 additions & 1 deletion tensorrt_llm/_torch/pyexecutor/py_executor.py
@@ -886,7 +886,9 @@ def _prepare_and_schedule_batch(self):
 
         if self.drafter is not None:
             self.use_spec_decode = self.drafter.should_use_spec_decode(
-                self.active_requests)
+                self.active_requests, self.max_batch_size,
+                self.model_engine.max_num_tokens,
+                self.model_engine.spec_config.max_draft_len)
             self.model_engine.enable_spec_decode = self.use_spec_decode
             # If speculation is off, this function sets py_draft_tokens to None
             # for all active requests. If it's on, we initialize py_draft_tokens
24 changes: 20 additions & 4 deletions tensorrt_llm/_torch/speculative/drafter.py
@@ -27,12 +27,28 @@ def prepare_draft_tokens(
         raise NotImplementedError
 
     @final
-    def should_use_spec_decode(self, requests: List[LlmRequest]) -> bool:
+    def should_use_spec_decode(self, requests: List[LlmRequest],
+                               max_batch_size: int, max_num_tokens: int,
+                               max_draft_len: int) -> bool:
         """
         You probably don't want to override this. ModelEngine
         assumes that speculation is always on if max_concurrency
         is not specified by the user's spec config.
         """
-        if self.max_concurrency is not None:
-            return len(requests) <= self.max_concurrency
-        return True
+
+        # Inputs typically validated upstream: max_batch_size>0, max_num_tokens>0, max_draft_len>=0
+
+        if self.max_concurrency is None:
+            return True
+
+        # Defensive guards; keep behavior explicit for zero/empty cases
+        if not requests or max_batch_size <= 0 or max_num_tokens <= 0:
+            return False
+
+        tokens_per_request = 1 + max(0, max_draft_len)
+        token_cap = max_num_tokens // tokens_per_request
+        if token_cap <= 0:
+            return False
+
+        num_effective_requests = min(len(requests), max_batch_size, token_cap)
+        return num_effective_requests <= self.max_concurrency
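
For reference, here is a minimal standalone sketch of the gating math introduced above, using made-up limits (the numbers are illustrative, not values taken from this PR): with max_draft_len = 4, each speculating request consumes up to 1 + 4 = 5 tokens of the engine's token budget, so the effective concurrency is the smallest of the active request count, max_batch_size, and max_num_tokens // 5.

# Illustrative recomputation of the should_use_spec_decode heuristic with
# hypothetical numbers; not part of the PR itself.
max_concurrency = 6
max_batch_size = 8
max_num_tokens = 28
max_draft_len = 4
num_active_requests = 12

tokens_per_request = 1 + max_draft_len  # 1 target token + up to 4 draft tokens
token_cap = max_num_tokens // tokens_per_request  # 28 // 5 = 5 requests fit the token budget
num_effective_requests = min(num_active_requests, max_batch_size, token_cap)  # min(12, 8, 5) = 5

# Speculation stays on because 5 <= 6
print(num_effective_requests <= max_concurrency)  # True
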
58 changes: 57 additions & 1 deletion tests/unittest/_torch/speculative/test_dynamic_spec_decode.py
@@ -51,7 +51,8 @@ def test_dynamic_spec_decode():
     )
 
     # Mock should_use_spec_decode to return True for first two calls, then False
-    def mock_should_use_spec_decode(self, requests):
+    def mock_should_use_spec_decode(self, requests, max_batch_size,
+                                    max_num_tokens, max_draft_len):
         if not hasattr(mock_should_use_spec_decode, 'call_count'):
             mock_should_use_spec_decode.call_count = 0
         mock_should_use_spec_decode.call_count += 1
@@ -86,5 +87,60 @@ def mock_should_use_spec_decode(self, requests):
     assert text_spec == text_ref
 
 
+def test_should_use_spec_decode():
+    from tensorrt_llm._torch.speculative.drafter import Drafter
+
+    class _DummyDrafter(Drafter):
+
+        def prepare_draft_tokens(self,
+                                 scheduled_requests,
+                                 resource_manager=None) -> None:
+            return
+
+    drafter = _DummyDrafter(max_concurrency=6)
+
+    # Compare min(len(requests), max_batch_size, token_cap) with max_concurrency
+
+    # Small active_requests ON case: num_effective_requests = min(5, 8, very_large) = 5 <= 6 → True
+    active_requests = [object()] * 5
+    assert drafter.should_use_spec_decode(active_requests,
+                                          max_batch_size=8,
+                                          max_num_tokens=4096 * 8,
+                                          max_draft_len=4)
+
+    # Small batch size ON case: num_effective_requests = min(12, 5, very_large) = 5 <= 6 → True
+    active_requests = [object()] * 12
+    assert drafter.should_use_spec_decode(active_requests,
+                                          max_batch_size=5,
+                                          max_num_tokens=4096 * 8,
+                                          max_draft_len=4)
+
+    # Small token budget ON case: token_cap = 28 // (1 + 4) = 5 → min(12, 8, 5) = 5 <= 6 → True
+    active_requests = [object()] * 12
+    assert drafter.should_use_spec_decode(active_requests,
+                                          max_batch_size=8,
+                                          max_num_tokens=28,
+                                          max_draft_len=4)
+
+    # Generic OFF case: num_effective_requests = min(12, 8, very_large) = 8 > 6 → False
+    active_requests = [object()] * 12
+    assert not drafter.should_use_spec_decode(active_requests,
+                                              max_batch_size=8,
+                                              max_num_tokens=4096 * 8,
+                                              max_draft_len=4)
+
+    # Edge case - empty active_requests OFF case
+    active_requests = []
+    assert not drafter.should_use_spec_decode(active_requests,
+                                              max_batch_size=8,
+                                              max_num_tokens=4096 * 8,
+                                              max_draft_len=4)
+
+    # Edge case - token cap equals 0 OFF case: token_cap = 4 // (1 + 4) = 0, so speculation is disabled regardless of the other limits → False
+    active_requests = [object()] * 12
+    assert not drafter.should_use_spec_decode(
+        active_requests, max_batch_size=8, max_num_tokens=4, max_draft_len=4)
+
+
 if __name__ == "__main__":
     unittest.main()
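
To exercise only the new gating test, one option (assuming pytest is available and the repository layout matches the paths shown in this diff) is to invoke it through pytest's Python API:

# Runs just the new unit test; the node id below assumes this PR's test file location.
import pytest

pytest.main([
    "tests/unittest/_torch/speculative/test_dynamic_spec_decode.py::test_should_use_spec_decode",
    "-q",
])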