@@ -443,22 +443,17 @@ def _update_target_inputs_with_draft_tokens(
443443 draft_length ] = draft_tensors [0 :draft_length , idx ]
444444
445445 def _setup_draft_batch_and_resources (
446- self ,
447- scheduled_batch : ScheduledRequests ,
448- guided_decoder : Optional [GuidedDecoder ] = None
446+ self , scheduled_batch : ScheduledRequests
449447 ) -> Tuple [Optional [ScheduledRequests ], Optional [Dict [int , LlmRequest ]]]:
450448 """
451449 Setup draft batch and prepare resources.
452450
453451 Args:
454452 scheduled_batch: The scheduled requests
455- guided_decoder: The guided decoder
456453
457454 Returns:
458455 Tuple of (draft_batch, req_id_to_old_request) or (None, None) if no batch
459456 """
460- if guided_decoder is not None :
461- guided_decoder .rollback_rejected_tokens (scheduled_batch )
462457
463458 draft_batch = self ._prepare_draft_batch (scheduled_batch )
464459 if draft_batch .batch_size == 0 :
@@ -510,11 +505,8 @@ def process_dynamic_draft_outputs(
510505 req_id_to_old_request )
511506
512507 def _execute_draft_iteration (
513- self ,
514- draft_batch : ScheduledRequests ,
515- resource_manager : ResourceManager ,
516- previous_draft_state : Optional [SampleState ],
517- guided_decoder : Optional [GuidedDecoder ] = None
508+ self , draft_batch : ScheduledRequests , resource_manager : ResourceManager ,
509+ previous_draft_state : Optional [SampleState ]
518510 ) -> Tuple [Any , Optional [SampleState ]]:
519511 """Forward pass through the draft model."""
520512 outputs = self .forward_draft_model (
@@ -527,11 +519,10 @@ def _execute_draft_iteration(
527519 if previous_draft_state is not None :
528520 self .update_requests (previous_draft_state )
529521
530- if guided_decoder is not None :
531- guided_decoder .add_batch (draft_batch )
532- guided_decoder .execute (outputs ['logits' ],
533- outputs ['logits' ],
534- d2t = outputs .get ('d2t' ))
522+ if self .guided_decoder is not None :
523+ self .guided_decoder .add_batch (draft_batch )
524+ self .guided_decoder .execute (outputs ['logits' ],
525+ d2t = outputs .get ('d2t' ))
535526
536527 sample_state = self .sample_async (draft_batch , outputs )
537528 self .update_request_states (draft_batch )
@@ -543,7 +534,6 @@ def _execute_draft_loop(
543534 draft_batch : ScheduledRequests ,
544535 resource_manager : ResourceManager ,
545536 req_id_to_old_request : Dict [int , LlmRequest ],
546- guided_decoder : Optional [GuidedDecoder ] = None ,
547537 target_inputs : Optional [SampleStateTensorsMTP ] = None ,
548538 num_draft_reqs : Optional [int ] = None ,
549539 initial_draft_state : Optional [SampleState ] = None
@@ -555,7 +545,6 @@ def _execute_draft_loop(
555545 draft_batch: The draft batch to process
556546 resource_manager: The resource manager
557547 req_id_to_old_request: Mapping from request ID to original request
558- guided_decoder: The guided decoder
559548 target_inputs: Optional target inputs to update (for overlap mode)
560549 num_draft_reqs: Number of draft requests (for overlap mode)
561550 initial_draft_state: The initial draft state from the first forward pass
@@ -575,8 +564,7 @@ def _execute_draft_loop(
575564 break
576565
577566 _ , sample_state = self ._execute_draft_iteration (
578- draft_batch , resource_manager , previous_draft_state ,
579- guided_decoder )
567+ draft_batch , resource_manager , previous_draft_state )
580568
581569 # Update target inputs if provided (for overlap mode)
582570 if target_inputs is not None and num_draft_reqs is not None :
@@ -603,8 +591,7 @@ def _execute_draft_loop(
603591 def generate_draft_tokens_with_overlap (
604592 self , scheduled_batch : ScheduledRequests ,
605593 resource_manager : ResourceManager ,
606- previous_tensors : Optional [SampleStateTensors ],
607- guided_decoder : Optional [GuidedDecoder ]
594+ previous_tensors : Optional [SampleStateTensors ]
608595 ) -> Tuple [Optional [SampleStateTensorsMTP ], Optional [Any ],
609596 Optional [ScheduledRequests ]]:
610597 """
@@ -622,7 +609,7 @@ def generate_draft_tokens_with_overlap(
622609 - Draft sample state or None
623610 """
624611 draft_batch , req_id_to_old_request = self ._setup_draft_batch_and_resources (
625- scheduled_batch , guided_decoder )
612+ scheduled_batch )
626613 if draft_batch is None :
627614 return None , None , None
628615
@@ -652,7 +639,6 @@ def generate_draft_tokens_with_overlap(
652639 if self .guided_decoder is not None :
653640 self .guided_decoder .add_batch (draft_batch )
654641 self .guided_decoder .execute (outputs ['logits' ],
655- outputs ['logits' ],
656642 d2t = outputs .get ('d2t' ))
657643 draft_sample_state = self .sample_async (draft_batch , outputs )
658644
@@ -669,8 +655,8 @@ def generate_draft_tokens_with_overlap(
669655
670656 # Execute the iterative draft loop
671657 previous_draft_state = self ._execute_draft_loop (
672- draft_batch , resource_manager , req_id_to_old_request ,
673- guided_decoder , target_inputs , num_draft_reqs , draft_sample_state )
658+ draft_batch , resource_manager , req_id_to_old_request , target_inputs ,
659+ num_draft_reqs , draft_sample_state )
674660
675661 return target_inputs , previous_draft_state , draft_batch
676662
@@ -719,8 +705,8 @@ def prepare_draft_tokens(
719705
720706 # Execute the iterative draft loop
721707 previous_draft_state = self ._execute_draft_loop (
722- draft_batch , resource_manager , req_id_to_old_request ,
723- self . guided_decoder , None , None , sample_state )
708+ draft_batch , resource_manager , req_id_to_old_request , None ,
709+ None , sample_state )
724710
725711 # Final cleanup
726712 if previous_draft_state is not None :