Skip to content

Commit 964399e

Browse files
committed
Resolve conflicts with constrained decoding
Signed-off-by: ziyixiong-nv <[email protected]>
1 parent ed6f81f commit 964399e

File tree

2 files changed

+15
-30
lines changed

2 files changed

+15
-30
lines changed

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1983,8 +1983,7 @@ def _handle_speculative_decoding(self, scheduled_batch, previous_tensors):
19831983

19841984
target_inputs, draft_outputs, draft_batch = self.drafter.generate_draft_tokens_with_overlap(
19851985
scheduled_batch, self.resource_manager,
1986-
previous_tensors.device if previous_tensors else None,
1987-
self.guided_decoder)
1986+
previous_tensors.device if previous_tensors else None)
19881987

19891988
self.has_previous_draft_tokens = target_inputs is not None and target_inputs.next_draft_tokens is not None
19901989
else:

tensorrt_llm/_torch/speculative/model_drafter.py

Lines changed: 14 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -443,22 +443,17 @@ def _update_target_inputs_with_draft_tokens(
443443
draft_length] = draft_tensors[0:draft_length, idx]
444444

445445
def _setup_draft_batch_and_resources(
446-
self,
447-
scheduled_batch: ScheduledRequests,
448-
guided_decoder: Optional[GuidedDecoder] = None
446+
self, scheduled_batch: ScheduledRequests
449447
) -> Tuple[Optional[ScheduledRequests], Optional[Dict[int, LlmRequest]]]:
450448
"""
451449
Setup draft batch and prepare resources.
452450
453451
Args:
454452
scheduled_batch: The scheduled requests
455-
guided_decoder: The guided decoder
456453
457454
Returns:
458455
Tuple of (draft_batch, req_id_to_old_request) or (None, None) if no batch
459456
"""
460-
if guided_decoder is not None:
461-
guided_decoder.rollback_rejected_tokens(scheduled_batch)
462457

463458
draft_batch = self._prepare_draft_batch(scheduled_batch)
464459
if draft_batch.batch_size == 0:
@@ -510,11 +505,8 @@ def process_dynamic_draft_outputs(
510505
req_id_to_old_request)
511506

512507
def _execute_draft_iteration(
513-
self,
514-
draft_batch: ScheduledRequests,
515-
resource_manager: ResourceManager,
516-
previous_draft_state: Optional[SampleState],
517-
guided_decoder: Optional[GuidedDecoder] = None
508+
self, draft_batch: ScheduledRequests, resource_manager: ResourceManager,
509+
previous_draft_state: Optional[SampleState]
518510
) -> Tuple[Any, Optional[SampleState]]:
519511
"""Forward pass through the draft model."""
520512
outputs = self.forward_draft_model(
@@ -527,11 +519,10 @@ def _execute_draft_iteration(
527519
if previous_draft_state is not None:
528520
self.update_requests(previous_draft_state)
529521

530-
if guided_decoder is not None:
531-
guided_decoder.add_batch(draft_batch)
532-
guided_decoder.execute(outputs['logits'],
533-
outputs['logits'],
534-
d2t=outputs.get('d2t'))
522+
if self.guided_decoder is not None:
523+
self.guided_decoder.add_batch(draft_batch)
524+
self.guided_decoder.execute(outputs['logits'],
525+
d2t=outputs.get('d2t'))
535526

536527
sample_state = self.sample_async(draft_batch, outputs)
537528
self.update_request_states(draft_batch)
@@ -543,7 +534,6 @@ def _execute_draft_loop(
543534
draft_batch: ScheduledRequests,
544535
resource_manager: ResourceManager,
545536
req_id_to_old_request: Dict[int, LlmRequest],
546-
guided_decoder: Optional[GuidedDecoder] = None,
547537
target_inputs: Optional[SampleStateTensorsMTP] = None,
548538
num_draft_reqs: Optional[int] = None,
549539
initial_draft_state: Optional[SampleState] = None
@@ -555,7 +545,6 @@ def _execute_draft_loop(
555545
draft_batch: The draft batch to process
556546
resource_manager: The resource manager
557547
req_id_to_old_request: Mapping from request ID to original request
558-
guided_decoder: The guided decoder
559548
target_inputs: Optional target inputs to update (for overlap mode)
560549
num_draft_reqs: Number of draft requests (for overlap mode)
561550
initial_draft_state: The initial draft state from the first forward pass
@@ -575,8 +564,7 @@ def _execute_draft_loop(
575564
break
576565

577566
_, sample_state = self._execute_draft_iteration(
578-
draft_batch, resource_manager, previous_draft_state,
579-
guided_decoder)
567+
draft_batch, resource_manager, previous_draft_state)
580568

581569
# Update target inputs if provided (for overlap mode)
582570
if target_inputs is not None and num_draft_reqs is not None:
@@ -603,8 +591,7 @@ def _execute_draft_loop(
603591
def generate_draft_tokens_with_overlap(
604592
self, scheduled_batch: ScheduledRequests,
605593
resource_manager: ResourceManager,
606-
previous_tensors: Optional[SampleStateTensors],
607-
guided_decoder: Optional[GuidedDecoder]
594+
previous_tensors: Optional[SampleStateTensors]
608595
) -> Tuple[Optional[SampleStateTensorsMTP], Optional[Any],
609596
Optional[ScheduledRequests]]:
610597
"""
@@ -622,7 +609,7 @@ def generate_draft_tokens_with_overlap(
622609
- Draft sample state or None
623610
"""
624611
draft_batch, req_id_to_old_request = self._setup_draft_batch_and_resources(
625-
scheduled_batch, guided_decoder)
612+
scheduled_batch)
626613
if draft_batch is None:
627614
return None, None, None
628615

@@ -652,7 +639,6 @@ def generate_draft_tokens_with_overlap(
652639
if self.guided_decoder is not None:
653640
self.guided_decoder.add_batch(draft_batch)
654641
self.guided_decoder.execute(outputs['logits'],
655-
outputs['logits'],
656642
d2t=outputs.get('d2t'))
657643
draft_sample_state = self.sample_async(draft_batch, outputs)
658644

@@ -669,8 +655,8 @@ def generate_draft_tokens_with_overlap(
669655

670656
# Execute the iterative draft loop
671657
previous_draft_state = self._execute_draft_loop(
672-
draft_batch, resource_manager, req_id_to_old_request,
673-
guided_decoder, target_inputs, num_draft_reqs, draft_sample_state)
658+
draft_batch, resource_manager, req_id_to_old_request, target_inputs,
659+
num_draft_reqs, draft_sample_state)
674660

675661
return target_inputs, previous_draft_state, draft_batch
676662

@@ -719,8 +705,8 @@ def prepare_draft_tokens(
719705

720706
# Execute the iterative draft loop
721707
previous_draft_state = self._execute_draft_loop(
722-
draft_batch, resource_manager, req_id_to_old_request,
723-
self.guided_decoder, None, None, sample_state)
708+
draft_batch, resource_manager, req_id_to_old_request, None,
709+
None, sample_state)
724710

725711
# Final cleanup
726712
if previous_draft_state is not None:

0 commit comments

Comments
 (0)