@@ -88,5 +88,62 @@ def mock_should_use_spec_decode(self, requests, max_batch_size,
8888 assert text_spec == text_ref
8989
9090
91+ def test_should_use_spec_decode ():
92+ from tensorrt_llm ._torch .speculative .drafter import Drafter
93+
94+ class _DummyDrafter (Drafter ):
95+
96+ def prepare_draft_tokens (self ,
97+ scheduled_requests ,
98+ resource_manager = None ) -> None :
99+ return
100+
101+ drafter = _DummyDrafter (max_concurrency = 6 )
102+
103+ # Compare min(len(requests), max_batch_size, token_cap) with max_concurrency
104+
105+ # Small active_requests ON case: num_effective_requests = min(5, 8, very_large) = 5 <= 6 → True
106+ active_requests = [object ()] * 5
107+ assert drafter .should_use_spec_decode (active_requests ,
108+ max_batch_size = 8 ,
109+ max_num_tokens = 4096 * 8 ,
110+ max_draft_len = 4 ) is True
111+
112+ # Small batch size ON case: num_effective_requests = min(12, 5, very_large) = 5 <= 6 → True
113+ active_requests = [object ()] * 12
114+ assert drafter .should_use_spec_decode (active_requests ,
115+ max_batch_size = 5 ,
116+ max_num_tokens = 4096 * 8 ,
117+ max_draft_len = 4 ) is True
118+
119+ # Small token budget ON case: token_cap = 28 // (1+4) = 5 → min(8, 12, 5) = 5 <= 6 → True
120+ active_requests = [object ()] * 12
121+ assert drafter .should_use_spec_decode (active_requests ,
122+ max_batch_size = 8 ,
123+ max_num_tokens = 28 ,
124+ max_draft_len = 4 ) is True
125+
126+ # Generic OFF case: num_effective_requests = min(12, 8, very_large) = 8 > 6 → False
127+ active_requests = [object ()] * 12
128+ assert drafter .should_use_spec_decode (active_requests ,
129+ max_batch_size = 8 ,
130+ max_num_tokens = 4096 * 8 ,
131+ max_draft_len = 4 ) is False
132+
133+ # Edge case - None active requests OFF case: num_effective_requests = min(0, 8, very_large) = 0 <= 6 → False
134+ active_requests = []
135+ assert drafter .should_use_spec_decode (active_requests ,
136+ max_batch_size = 8 ,
137+ max_num_tokens = 4096 * 8 ,
138+ max_draft_len = 4 ) is False
139+
140+ # Edge case - Token cap equals 0 OFF case: token_cap = 4 // (1+4) = 0 → min(12, 8, 0) = 0 <= 6 → False
141+ active_requests = [object ()] * 12
142+ assert drafter .should_use_spec_decode (active_requests ,
143+ max_batch_size = 8 ,
144+ max_num_tokens = 4 ,
145+ max_draft_len = 4 ) is False
146+
147+
91148if __name__ == "__main__" :
92149 unittest .main ()
0 commit comments