
Commit c970d04

Add test cases for should_use_spec_decode.

Signed-off-by: Zheyu Fu <[email protected]>
1 parent: 06875c2

File tree: 3 files changed (+65, -5 lines)


tensorrt_llm/_torch/pyexecutor/py_executor.py (1 addition, 1 deletion)

@@ -923,7 +923,7 @@ def _prepare_and_schedule_batch(self):
         self.use_spec_decode = self.drafter.should_use_spec_decode(
             self.active_requests, self.max_batch_size,
             self.model_engine.max_num_tokens,
-            self.model_engine.max_draft_len)
+            self.model_engine.spec_config.max_draft_len)
         self.model_engine.enable_spec_decode = self.use_spec_decode
         # If speculation is off, this function sets py_draft_tokens to None
         # for all active requests. If it's on, we initialize py_draft_tokens

tensorrt_llm/_torch/speculative/drafter.py (7 additions, 4 deletions)

@@ -35,10 +35,13 @@ def should_use_spec_decode(self, requests: List[LlmRequest],
         assumes that speculation is always on if max_concurrency
         is not specified by the user's spec config.
         """
+
+        # Inputs validated upstream: max_batch_size>0, max_num_tokens>0, max_draft_len>0
+
         if self.max_concurrency is None:
             return True
 
-        tokens_per_request = 1 + max_draft_len
-        token_cap = max_num_tokens // tokens_per_request
-        num_effective_requests = min(max_batch_size, len(requests), token_cap)
-        return num_effective_requests <= self.max_concurrency
+        token_cap = max_num_tokens // (1 + max_draft_len)
+        num_effective_requests = min(len(requests), max_batch_size, token_cap)
+
+        return 0 < num_effective_requests <= self.max_concurrency
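
The gating arithmetic is easy to check by hand. Below is a standalone sketch of the same logic, not the library API; the function name and signature are illustrative:

# Standalone sketch of the gate, mirroring the diff above:
def gate(num_requests, max_batch_size, max_num_tokens, max_draft_len,
         max_concurrency):
    # Each speculating request consumes one target token plus max_draft_len drafts.
    token_cap = max_num_tokens // (1 + max_draft_len)
    num_effective = min(num_requests, max_batch_size, token_cap)
    # Speculate only for a non-empty effective batch no larger than max_concurrency.
    return 0 < num_effective <= max_concurrency

# Worked examples matching the new tests (max_concurrency=6):
assert gate(12, 8, 28, 4, 6) is True         # token_cap = 28 // 5 = 5; min(12, 8, 5) = 5 <= 6
assert gate(12, 8, 4096 * 8, 4, 6) is False  # min(12, 8, very_large) = 8 > 6
assert gate(0, 8, 4096 * 8, 4, 6) is False   # empty batch fails the 0 < lower bound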

tests/unittest/_torch/speculative/test_dynamic_spec_decode.py (57 additions, 0 deletions)

@@ -88,5 +88,62 @@ def mock_should_use_spec_decode(self, requests, max_batch_size,
     assert text_spec == text_ref
 
 
+def test_should_use_spec_decode():
+    from tensorrt_llm._torch.speculative.drafter import Drafter
+
+    class _DummyDrafter(Drafter):
+
+        def prepare_draft_tokens(self,
+                                 scheduled_requests,
+                                 resource_manager=None) -> None:
+            return
+
+    drafter = _DummyDrafter(max_concurrency=6)
+
+    # Compare min(len(requests), max_batch_size, token_cap) with max_concurrency.
+
+    # Small active_requests ON case: min(5, 8, very_large) = 5 <= 6 → True
+    active_requests = [object()] * 5
+    assert drafter.should_use_spec_decode(active_requests,
+                                          max_batch_size=8,
+                                          max_num_tokens=4096 * 8,
+                                          max_draft_len=4) is True
+
+    # Small batch size ON case: min(12, 5, very_large) = 5 <= 6 → True
+    active_requests = [object()] * 12
+    assert drafter.should_use_spec_decode(active_requests,
+                                          max_batch_size=5,
+                                          max_num_tokens=4096 * 8,
+                                          max_draft_len=4) is True
+
+    # Small token budget ON case: token_cap = 28 // (1 + 4) = 5 → min(12, 8, 5) = 5 <= 6 → True
+    active_requests = [object()] * 12
+    assert drafter.should_use_spec_decode(active_requests,
+                                          max_batch_size=8,
+                                          max_num_tokens=28,
+                                          max_draft_len=4) is True
+
+    # Generic OFF case: min(12, 8, very_large) = 8 > 6 → False
+    active_requests = [object()] * 12
+    assert drafter.should_use_spec_decode(active_requests,
+                                          max_batch_size=8,
+                                          max_num_tokens=4096 * 8,
+                                          max_draft_len=4) is False
+
+    # Edge case - no active requests OFF case: min(0, 8, very_large) = 0 fails the 0 < check → False
+    active_requests = []
+    assert drafter.should_use_spec_decode(active_requests,
+                                          max_batch_size=8,
+                                          max_num_tokens=4096 * 8,
+                                          max_draft_len=4) is False
+
+    # Edge case - token cap equals 0 OFF case: token_cap = 4 // (1 + 4) = 0 → min(12, 8, 0) = 0 fails the 0 < check → False
+    active_requests = [object()] * 12
+    assert drafter.should_use_spec_decode(active_requests,
+                                          max_batch_size=8,
+                                          max_num_tokens=4,
+                                          max_draft_len=4) is False
+
+
 if __name__ == "__main__":
     unittest.main()
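
To run just the added test, a pytest invocation along these lines should work (standard -k filtering; the file path is taken from the diff header):

pytest tests/unittest/_torch/speculative/test_dynamic_spec_decode.py -k test_should_use_spec_decode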
