diff --git a/tensorrt_llm/_torch/pyexecutor/sampler.py b/tensorrt_llm/_torch/pyexecutor/sampler.py index 18edf8946c8..dab2af134b4 100644 --- a/tensorrt_llm/_torch/pyexecutor/sampler.py +++ b/tensorrt_llm/_torch/pyexecutor/sampler.py @@ -1068,7 +1068,8 @@ def _tree_sampling_batch(self, requests: list[LlmRequest], return new_draft_tokens_host def _process_draft_tokens_rejection_sampling( - self, request: LlmRequest, new_tokens: torch.Tensor) -> int: + self, request: LlmRequest, new_tokens: list[list[list[int]]], + new_tokens_tensor: torch.Tensor) -> int: sampling_strategy = _request_strategy(request) generator = self.get_generator(request.py_draft_logits.device) _, draft_probs = sample(sampling_strategy, @@ -1089,8 +1090,8 @@ def _process_draft_tokens_rejection_sampling( num_accepted = num_initially_accepted for i in range(num_accepted): new_token = request.py_draft_tokens[i] - new_tokens[i, request.seq_slot, self.BEAM] = new_token - new_token = add_token(request, new_tokens, beam=self.BEAM, step=i) + new_tokens_tensor[i, request.seq_slot, self.BEAM] = new_token + request.add_new_token(new_token, self.BEAM) stop = self._handle_stop_criteria(request, new_token) if stop: num_accepted = i + 1 @@ -1098,11 +1099,14 @@ def _process_draft_tokens_rejection_sampling( if sample_last: new_token = sample_rejected(draft_probs, target_probs, generator, num_accepted) - new_tokens[num_accepted, request.seq_slot, self.BEAM] = new_token - new_token = add_token(request, - new_tokens, - beam=self.BEAM, - step=num_accepted) + new_tokens_tensor[num_accepted, request.seq_slot, + self.BEAM] = new_token + request.add_new_token(new_token, self.BEAM) + else: + new_token = add_token(request, + new_tokens, + beam=self.BEAM, + step=num_accepted) stop = self._handle_stop_criteria(request, new_token) return num_accepted @@ -1110,7 +1114,8 @@ def _process_draft_tokens_rejection_sampling( def process_draft_tokens( self, request: LlmRequest, - new_tokens: torch.Tensor, + new_tokens: list[list[list[int]]], + new_tokens_tensor: torch.Tensor, resource_manager: Optional[ResourceManager] = None) -> int: if request.py_draft_logits is None: spec_tree_manager = self.get_spec_tree_manager(resource_manager) @@ -1125,7 +1130,7 @@ def process_draft_tokens( return num_accepted else: return self._process_draft_tokens_rejection_sampling( - request, new_tokens) + request, new_tokens, new_tokens_tensor) @override def update_requests( @@ -1151,6 +1156,7 @@ def update_requests( continue processed = 1 num_accepted = self.process_draft_tokens(req, new_tokens, + state.host.new_tokens, resource_manager) if get_draft_token_length(req) > 0: req.py_num_accepted_draft_tokens = num_accepted