@@ -1183,6 +1183,12 @@ def sample_async(
             model_outputs: dict[str, torch.Tensor],
             num_context_logits_prefix_sum: list[int],
             resource_manager: Optional[ResourceManager] = None) -> SampleState:
+        # NB: The sampler is either called directly by PyExecutor, for the target model,
+        # or by ModelDrafter.prepare_draft_tokens(), for the draft model. In the former
+        # case there are 1 + get_draft_token_length(request) tokens per request. In the
+        # latter case, there is always only 1 token per request because draft
+        # tokens are sampled one-by-one.
+
         requests = scheduled_requests.all_requests()
         new_tokens = self.store.new_tokens
         log_probs_host = self.log_probs_host(scheduled_requests)
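For reference, here is a minimal sketch (not part of the diff) of the per-request token-count invariant described in the new comment. It assumes a request object exposing py_draft_tokens and a helper equivalent to get_draft_token_length; both are taken from the comment text rather than verified against the repository.

def expected_sampled_tokens(request, called_from_drafter: bool) -> int:
    # get_draft_token_length(request) is assumed to reduce to the length of
    # the request's pending draft tokens (zero when none are attached).
    draft_len = len(getattr(request, "py_draft_tokens", None) or [])
    if called_from_drafter:
        # ModelDrafter.prepare_draft_tokens() samples draft tokens one-by-one,
        # so each call yields exactly one token per request.
        return 1
    # PyExecutor path (target model): one new token plus one token for each
    # draft token position that has to be verified.
    return 1 + draft_len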
@@ -1332,8 +1338,6 @@ def _sample_batched_by_strategy(
             requests, pin_memory=True)
         generator_cuda = self.get_generator(cuda_device)
 
-        # FIXME: This check should/could be performed in ModelDrafter.prepare_draft_tokens
-        #
         # NB: Currently, "d2t" is applied to draft tokens, but not to draft logits,
         # breaking _process_draft_tokens_rejection_sampling.
         needs_d2t = "d2t" in model_outputs
@@ -1459,15 +1463,16 @@ def _sample_batched_by_strategy(
         (batch_req_indices, batch_next_tokens_cuda_int,
          batch_softmax_cuda), = batched_results
 
-        # FIXME: This should be done in ModelDrafter.prepare_draft_tokens, but for performance
-        # parity py_draft_tokens might need to be replaced / backed by a torch.Tensor, so
-        # that d2t can be applied in a batched manner similar to the code below.
+        # NB: 'd2t' contains offsets for transforming draft vocab token IDs into
+        # the target vocab. This is used by Eagle3ForCausalLM, whose input domain
+        # is the target vocab, whereas the output logits correspond to the draft
+        # vocab. Since the inputs/outputs are linked by TorchSampler.update_requests,
+        # they currently need to be handled within TorchSampler. Changing the model
+        # outputs to use the target vocab would require inflating the logit tensors,
+        # which is inefficient. Changing the inputs to use the draft vocab might
+        # be cleaner, but would require applying 'd2t' in multiple locations:
+        # Prefill, Eagle3ForCausalLM embeddings, ModelDrafter.
         if needs_d2t:
-            # NB: The sampler is either called directly by PyExecutor, for the target model,
-            # or by ModelDrafter.prepare_draft_tokens(), for the draft model. In the former
-            # case there are 1 + get_draft_token_length(request) tokens per request. In the
-            # latter case, only there is always only 1 token per request because draft
-            # tokens are sampled one-by-one.
             self._apply_d2t(batch_next_tokens_cuda_int, model_outputs)
 
         return _BatchedSamplingResult(
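To make the new comment concrete, the following is a small, hedged sketch of a batched 'd2t' offset application in the spirit of _apply_d2t. The offset-table semantics (target_id = draft_id + d2t[draft_id]) and the tensor names are assumptions inferred from the comment, not the actual implementation.

import torch

def apply_d2t_sketch(draft_token_ids: torch.Tensor,
                     d2t: torch.Tensor) -> torch.Tensor:
    # 'd2t' is assumed to be a 1-D tensor of per-draft-token-ID offsets, so a
    # single gather plus an add maps an entire batch of sampled draft-vocab
    # token IDs into the target vocab in one vectorized operation.
    return draft_token_ids + d2t[draft_token_ids]

# Toy usage: with d2t = torch.tensor([0, 3, 7]) and draft IDs
# torch.tensor([2, 0, 1]), the sketch returns tensor([9, 0, 4]).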
@@ -1909,7 +1914,6 @@ def sample_async(
             num_context_logits_prefix_sum: list[int],
             resource_manager: Optional[ResourceManager] = None
     ) -> SampleStateTRTLLM:
-
         batch_size = scheduled_requests.batch_size
         beam_width = self.beam_width(scheduled_requests.all_requests())
         if (batch_size > 1 and beam_width > 1