@@ -221,11 +221,11 @@ class Args:
221221 max_draft_tokens : int
222222 max_num_sequences : int
223223 max_beam_width : int
224- mixed_sampler : bool
224+ enable_mixed_sampler : bool
225225
226226 def __init__ (self , args : Args ):
227227 self .max_seq_len = args .max_seq_len
228- self .mixed_sampler = args .mixed_sampler
228+ self .enable_mixed_sampler = args .enable_mixed_sampler
229229 self .max_tokens = args .max_draft_tokens + 1
230230 assert args .max_beam_width == self .MAX_BEAM_WIDTH , "TorchSampler only supports beam_width = 1"
231231 self .num_seq_slots = args .max_num_sequences
@@ -402,7 +402,7 @@ def _process_requests(self,
402402 num_steps = [1 + len (req .py_draft_tokens ) for req in requests ]
403403 sum_steps = sum (num_steps )
404404 no_draft_tokens = len (requests ) == sum_steps
405- fast_path = not self .mixed_sampler and no_draft_tokens and gen_logits_host is None and log_probs_host is None
405+ fast_path = not self .enable_mixed_sampler and no_draft_tokens and gen_logits_host is None and log_probs_host is None
406406
407407 seq_slots = torch .as_tensor ([r .seq_slot for r in requests ])
408408 seq_slots = seq_slots .to (device = "cuda" , non_blocking = True )
@@ -419,7 +419,7 @@ def _process_requests(self,
419419 strategies = sampling_strategies (requests )
420420 batched_next_tokens , batched_softmax = None , None
421421 batched_strategy : Strategy | None = GREEDY
422- if self .mixed_sampler :
422+ if self .enable_mixed_sampler :
423423 assert "d2t" not in model_outputs , "eagle3 does not yet support non-greedy sampling"
424424 if len (set (strategies )) == 1 :
425425 batched_strategy = strategies [0 ]
0 commit comments