From c4f18e6da25aba76aee48c78832e89247ea28673 Mon Sep 17 00:00:00 2001 From: Netanel Haber Date: Thu, 26 Jun 2025 07:56:56 +0000 Subject: [PATCH 1/4] feature: unify new_tokens format sample state to trtllm samper new_tokens format Signed-off-by: Netanel Haber --- .../_torch/auto_deploy/shim/ad_executor.py | 24 +- tensorrt_llm/_torch/pyexecutor/_util.py | 51 +- .../_torch/pyexecutor/guided_decoder.py | 8 +- tensorrt_llm/_torch/pyexecutor/llm_request.py | 2 + .../_torch/pyexecutor/model_engine.py | 125 ++--- tensorrt_llm/_torch/pyexecutor/py_executor.py | 95 ++-- tensorrt_llm/_torch/pyexecutor/sampler.py | 506 ++++++++---------- tensorrt_llm/_torch/pyexecutor/scheduler.py | 6 +- .../_torch/pyexecutor/seq_slot_manager.py | 8 +- tensorrt_llm/_torch/speculative/eagle3.py | 49 +- tensorrt_llm/_torch/speculative/mtp.py | 21 +- tensorrt_llm/_torch/speculative/utils.py | 25 +- 12 files changed, 418 insertions(+), 502 deletions(-) diff --git a/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py b/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py index 840241641a2..a733f37b1b2 100644 --- a/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py +++ b/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py @@ -4,6 +4,7 @@ import torch from torch._prims_common import DeviceLikeType +from tensorrt_llm._torch.pyexecutor.seq_slot_manager import SeqSlotManager from tensorrt_llm._utils import nvtx_range from ...._utils import mpi_rank, mpi_world_size @@ -264,6 +265,7 @@ def create_autodeploy_executor(executor_config: ExecutorConfig, checkpoint_dir: ad_config: _AutoDeployLlmArgs = executor_config.pytorch_backend_config max_batch_size = ad_config.max_batch_size + max_num_sequences = ad_config.max_batch_size * dist_mapping.pp_size max_seq_len = ad_config.max_seq_len attn_page_size = ad_config.attn_page_size max_num_tokens = ad_config.max_num_tokens @@ -294,7 +296,13 @@ def create_autodeploy_executor(executor_config: ExecutorConfig, checkpoint_dir: max_seq_len=max_seq_len, max_batch_size=max_batch_size, ) - resource_manager = ResourceManager({ResourceManagerType.KV_CACHE_MANAGER: kv_cache_manager}) + seq_slot_manager = SeqSlotManager(max_num_sequences=max_batch_size * dist_mapping.pp_size) + resource_manager = ResourceManager( + { + ResourceManagerType.KV_CACHE_MANAGER: kv_cache_manager, + ResourceManagerType.SEQ_SLOT_MANAGER: seq_slot_manager, + } + ) resource_manager.resource_managers.move_to_end(ResourceManagerType.KV_CACHE_MANAGER, last=True) # scheduling @@ -305,7 +313,18 @@ def create_autodeploy_executor(executor_config: ExecutorConfig, checkpoint_dir: scheduler = SimpleScheduler(capacitor_scheduler, mb_scheduler) # search sampler with speculative decoding - sampler = TorchSampler(max_seq_len=max_seq_len) + # TODO (lucaslie, fridah-nv): some models require mixed_sampler=True to have good outputs, see + # https://github.com/NVIDIA/TensorRT-LLM/issues/5254 + # We should expose mixed_sample to our build_and_run_ad script so we can configure this + # correctly for models as needed. 
+ sampler_args = TorchSampler.Args( + max_seq_len=max_seq_len, + max_draft_tokens=max_draft_tokens, + max_num_sequences=max_num_sequences, + max_beam_width=executor_config.max_beam_width, + mixed_sampler=ad_config.mixed_sampler, + ) + sampler = TorchSampler(sampler_args) # creating the executor object py_executor = PyExecutor( @@ -314,6 +333,7 @@ def create_autodeploy_executor(executor_config: ExecutorConfig, checkpoint_dir: model_engine=engine, sampler=sampler, dist=mpi_dist, + max_num_sequences=max_num_sequences, disable_overlap_scheduler=ad_config.disable_overlap_scheduler, max_input_len=ad_config.max_input_len, max_batch_size=max_batch_size, diff --git a/tensorrt_llm/_torch/pyexecutor/_util.py b/tensorrt_llm/_torch/pyexecutor/_util.py index 4000f39329f..6306afc1ccc 100644 --- a/tensorrt_llm/_torch/pyexecutor/_util.py +++ b/tensorrt_llm/_torch/pyexecutor/_util.py @@ -26,8 +26,7 @@ from .resource_manager import (KVCacheManager, MambaHybridCacheManager, PeftCacheManager, ResourceManager, ResourceManagerType) -from .sampler import (EarlyStopSampler, TorchSampler, TorchStarAttentionSampler, - TRTLLMSampler) +from .sampler import EarlyStopSampler, TorchSampler, TRTLLMSampler from .scheduler import (BindCapacityScheduler, BindMicroBatchScheduler, SimpleScheduler) from .seq_slot_manager import SeqSlotManager @@ -512,6 +511,7 @@ def create_py_executor_instance( model_engine=model_engine, sampler=sampler, dist=dist, + max_num_sequences=max_num_sequences, disable_overlap_scheduler=pytorch_backend_config. disable_overlap_scheduler, max_batch_size=executor_config.max_batch_size, @@ -523,31 +523,44 @@ def create_py_executor_instance( garbage_collection_gen0_threshold=garbage_collection_gen0_threshold) -def instantiate_sampler(model_engine: PyTorchModelEngine, +def create_torch_sampler_args(executor_config: ExecutorConfig, mapping: Mapping, + *, max_seq_len: int, mixed_sampler: bool): + max_num_sequences = executor_config.max_batch_size * mapping.pp_size + max_draft_tokens = (0 if executor_config.speculative_config is None else + executor_config.speculative_config.max_draft_tokens) + return TorchSampler.Args( + max_seq_len=max_seq_len, + max_draft_tokens=max_draft_tokens, + max_num_sequences=max_num_sequences, + max_beam_width=executor_config.max_beam_width, + mixed_sampler=mixed_sampler, + ) + + +def instantiate_sampler(engine: PyTorchModelEngine, executor_config: ExecutorConfig, pytorch_backend_config: PyTorchConfig, mapping: Mapping): + sampler_args = create_torch_sampler_args( + executor_config, + mapping, + max_seq_len=engine.max_seq_len, + mixed_sampler=pytorch_backend_config.mixed_sampler) if mapping.cp_config.get('cp_type') == 'star_attention': assert pytorch_backend_config.attn_backend == "FLASHINFER_STAR_ATTENTION", "attention backend of star attention should be 'FLASHINFER_STAR_ATTENTION'" - sampler = TorchStarAttentionSampler( - max_seq_len=model_engine.max_seq_len) - elif model_engine.spec_config is not None and model_engine.spec_config.spec_dec_mode.has_spec_decoder( + return TorchSampler(sampler_args) + if engine.spec_config is not None and engine.spec_config.spec_dec_mode.has_spec_decoder( ): - sampler = get_spec_decoder(max_seq_len=model_engine.max_seq_len, - spec_config=model_engine.spec_config) - elif pytorch_backend_config.enable_trtllm_sampler: + return get_spec_decoder(sampler_args, engine.spec_config) + if pytorch_backend_config.enable_trtllm_sampler: decoding_mode = get_decoding_mode(executor_config) - sampler = TRTLLMSampler( - executor_config, model_engine.model, 
model_engine.dtype, mapping, - decoding_mode, pytorch_backend_config.disable_overlap_scheduler) - elif not model_engine.model.model_config.is_generation: + return TRTLLMSampler(executor_config, engine.model, engine.dtype, + mapping, decoding_mode, + pytorch_backend_config.disable_overlap_scheduler) + if not engine.model.model_config.is_generation: # NOTE: choose sampler based on model type - sampler = EarlyStopSampler() - else: - sampler = TorchSampler( - max_seq_len=model_engine.max_seq_len, - mixed_sampler=pytorch_backend_config.mixed_sampler) - return sampler + return EarlyStopSampler() + return TorchSampler(sampler_args) def get_decoding_mode(executor_config: ExecutorConfig) -> DecodingMode: diff --git a/tensorrt_llm/_torch/pyexecutor/guided_decoder.py b/tensorrt_llm/_torch/pyexecutor/guided_decoder.py index fc21a2096e2..756c177a6ea 100644 --- a/tensorrt_llm/_torch/pyexecutor/guided_decoder.py +++ b/tensorrt_llm/_torch/pyexecutor/guided_decoder.py @@ -1,4 +1,3 @@ -import itertools import math from typing import List, Optional @@ -52,8 +51,7 @@ def bitmask_size(self) -> int: def build(self, scheduled_requests: ScheduledRequests, resource_manager: SeqSlotManager) -> None: - for llm_req in itertools.chain(scheduled_requests.context_requests, - scheduled_requests.generation_requests): + for llm_req in scheduled_requests.all_requests(): if llm_req.guided_decoding_params is None: continue slot = resource_manager.slot_manager.get_slot(llm_req.request_id) @@ -84,9 +82,7 @@ def execute(self, scheduled_requests: ScheduledRequests, torch.cuda.current_stream().wait_stream(self._stream) batched_logits, batched_bitmask = [], [] - for i, llm_req in enumerate( - itertools.chain(scheduled_requests.context_requests, - scheduled_requests.generation_requests)): + for i, llm_req in enumerate(scheduled_requests.all_requests()): if llm_req.guided_decoding_params is None: continue if llm_req.is_context_init_state and not llm_req.is_last_context_chunk: diff --git a/tensorrt_llm/_torch/pyexecutor/llm_request.py b/tensorrt_llm/_torch/pyexecutor/llm_request.py index 01e9324e987..f16e4e2dcfa 100644 --- a/tensorrt_llm/_torch/pyexecutor/llm_request.py +++ b/tensorrt_llm/_torch/pyexecutor/llm_request.py @@ -253,6 +253,7 @@ def __init__( return_logits_device_memory: bool = True, exclude_last_generation_logits: bool = False, stop_words_list: list[list[int]] | None = None, + is_draft: bool = False, **kwargs): self.py_logits_post_processors = kwargs.pop("py_logits_post_processors", None) @@ -286,6 +287,7 @@ def __init__( self.py_return_context_logits = return_context_logits self.py_return_generation_logits = return_generation_logits self.py_return_logits_device_memory = return_logits_device_memory + self.py_is_draft = is_draft # TODO: remove this when use DynamicDecodeOp in pytorch flow. # currently, keep py_stop_words_list as python list, rather than tensor. 
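A minimal sketch (not part of the patch) of the unified new_tokens layout that the model_engine.py and sampler.py changes below rely on: tokens are stored as new_tokens[step, seq_slot, beam], mirroring the cpp DecoderState.getAllNewTokens() convention, and each request reads its tokens via its seq_slot rather than a serially assigned batch index. The FakeRequest class here is a hypothetical stand-in for LlmRequest, kept only to make the example self-contained; the toy shape values are assumptions.

import torch

# Illustrative sketch only: the unified layout is new_tokens[step, seq_slot, beam],
# matching cpp DecoderState.getAllNewTokens(). Shapes are toy values.
MAX_TOKENS = 3          # max_draft_tokens + 1
MAX_NUM_SEQUENCES = 4   # max_batch_size * pp_size
BEAM_WIDTH = 1          # TorchSampler only supports beam_width = 1
new_tokens = torch.zeros((MAX_TOKENS, MAX_NUM_SEQUENCES, BEAM_WIDTH), dtype=torch.int)

class FakeRequest:
    """Stand-in for LlmRequest; only the fields this sketch needs."""

    def __init__(self, seq_slot: int):
        self.seq_slot = seq_slot
        self.tokens = []

    def add_new_token(self, token: int, beam: int) -> int:
        self.tokens.append(token)
        return len(self.tokens)

def add_token(request, new_tokens: torch.Tensor, *, beam: int, step: int = 0) -> int:
    # Same indexing as TorchSampler.add_token in this patch: read the token the
    # sampler wrote for this request's sequence slot at the given draft step.
    new_token = int(new_tokens[step, request.seq_slot, beam])
    request.add_new_token(new_token, beam)
    return new_token

request = FakeRequest(seq_slot=2)
new_tokens[0, request.seq_slot, 0] = 42   # the sampler wrote a token for slot 2
assert add_token(request, new_tokens, beam=0) == 42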
diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py index aa0484867c1..c58b4ca266e 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_engine.py +++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py @@ -4,7 +4,6 @@ import gc import glob import inspect -import itertools import math import multiprocessing import os @@ -21,6 +20,7 @@ import torch._dynamo.config import tensorrt_llm.bindings.internal.userbuffers as ub +from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequest from tensorrt_llm._torch.pyexecutor.sampler import SampleStateTensors from tensorrt_llm._torch.speculative.mtp import SampleStateTensorsMTP from tensorrt_llm._utils import (is_trace_enabled, local_mpi_rank, @@ -319,6 +319,7 @@ def _filter_cuda_graph_batch_sizes(cuda_graph_batch_sizes: list[int], class PyTorchModelEngine(ModelEngine): + BEAM_WIDTH = 1 def __init__( self, @@ -659,13 +660,12 @@ def get_autotune_warmup_request(): return result @contextlib.contextmanager - def release_batch(result): + def release_batch(result: ScheduledRequests | None): try: yield result finally: if result is not None: - for req in itertools.chain(result.generation_requests, - result.context_requests): + for req in result.all_requests(): kv_cache_manager.free_resources(req) if spec_resource_manager is not None: spec_resource_manager.free_resources(req) @@ -1153,7 +1153,15 @@ def _prepare_tp_inputs( draft_lens = [] mrope_config = defaultdict(list) - batch_idx = 0 + mtp_batch_idx = 0 # Temporary: MTP (and Eagle3OneModel) remain the only samplers to index new_tokens serially + + def py_batch_idx(request: LlmRequest) -> int: + if not self.without_logits: + return request.seq_slot + nonlocal mtp_batch_idx + batch_idx = mtp_batch_idx + mtp_batch_idx += 1 + return batch_idx for request in scheduled_requests.context_requests: request_ids.append(request.py_request_id) @@ -1184,10 +1192,9 @@ def _prepare_tp_inputs( ) if mrope_rotary_cos_sin.device == 'cpu' else mrope_rotary_cos_sin mrope_config['mrope_rotary_cos_sin'].append( mrope_rotary_cos_sin.to('cuda', non_blocking=True)) - request.py_batch_idx = batch_idx - batch_idx += 1 + request.py_batch_idx = py_batch_idx(request) - num_ctx_requests = batch_idx + num_ctx_requests = len(scheduled_requests.context_requests) num_ctx_tokens = len(input_ids) new_tokens_device, new_tokens_lens_device, next_draft_tokens_device = None, None, None if new_tensors_device is not None: @@ -1227,7 +1234,7 @@ def _prepare_tp_inputs( assert spec_dec_mode.support_overlap_scheduler( ), f"{self.spec_config.spec_dec_name} does not support overlap scheduler" - # will contain previous batch incices of generation requests + # will contain previous batch indices of generation requests previous_batch_indices = [] previous_pos_indices = [] for request in extend_requests: @@ -1267,13 +1274,11 @@ def _prepare_tp_inputs( num_cached_tokens_per_seq.append(past_seen_token_num) request_ids.append(request.py_request_id) # update batch index - request.py_batch_idx = batch_idx - batch_idx += 1 + request.py_batch_idx = py_batch_idx(request) else: # update batch index previous_batch_idx = request.py_batch_idx - request.py_batch_idx = batch_idx - batch_idx += 1 + request.py_batch_idx = py_batch_idx(request) # inputs # overlap scheduler can only support the speculative decoding # methods with a fixed number of draft tokens @@ -1324,12 +1329,21 @@ def _prepare_tp_inputs( prompt_lengths.append(request.py_prompt_len) draft_lens.append(0) - request.py_batch_idx = batch_idx - batch_idx += 1 + 
request.py_batch_idx = py_batch_idx(request) + + previous_batch_len = len(previous_batch_indices) + + def previous_seq_slots_device(): + previous_batch_indices_host = torch.tensor(previous_batch_indices, + dtype=torch.int, + pin_memory=True) + previous_slots = self.previous_batch_indices_cuda[: + previous_batch_len] + previous_slots.copy_(previous_batch_indices_host, non_blocking=True) + return previous_slots num_tokens = len(input_ids) num_draft_tokens = len(draft_tokens) - previous_batchs = len(previous_batch_indices) num_requests = len(request_ids) total_num_tokens = len(position_ids) assert total_num_tokens <= self.max_num_tokens, ( @@ -1347,67 +1361,55 @@ def _prepare_tp_inputs( self.draft_tokens_cuda[:len(draft_tokens)].copy_(draft_tokens, non_blocking=True) if next_draft_tokens_device is not None: - if len(previous_batch_indices) > 0: - previous_batch_indices = torch.tensor(previous_batch_indices, - dtype=torch.int, - pin_memory=True) - self.previous_batch_indices_cuda[:previous_batchs].copy_( - previous_batch_indices, non_blocking=True) + if previous_batch_len > 0: + previous_slots = previous_seq_slots_device() # previous input ids - previous_batch_tokens = previous_batchs * (1 + - self.max_draft_len) - self.input_ids_cuda[ - num_tokens:num_tokens + - previous_batch_tokens].copy_(new_tokens_device[ - self.previous_batch_indices_cuda[:previous_batchs], :]. - flatten(), - non_blocking=True) + previous_batch_tokens = previous_batch_len * ( + 1 + self.max_draft_len) + new_tokens = new_tokens_device[previous_slots, :].flatten() + self.input_ids_cuda[num_tokens:num_tokens + + previous_batch_tokens].copy_( + new_tokens, non_blocking=True) # previous draft tokens - previous_batch_draft_tokens = previous_batchs * self.max_draft_len - self.draft_tokens_cuda[ - num_draft_tokens:num_draft_tokens + - previous_batch_draft_tokens].copy_(next_draft_tokens_device[ - self.previous_batch_indices_cuda[:previous_batchs], :]. 
- flatten(), - non_blocking=True) + previous_batch_draft_tokens = previous_batch_len * self.max_draft_len + self.draft_tokens_cuda[num_draft_tokens:num_draft_tokens + + previous_batch_draft_tokens].copy_( + next_draft_tokens_device[ + previous_slots, :].flatten(), + non_blocking=True) # prepare data for the preprocess inputs kv_len_offsets_device = new_tokens_lens_device - self.max_draft_len - 1 - previous_pos_indices = torch.tensor(previous_pos_indices, - dtype=torch.int, - pin_memory=True) + previous_pos_indices_host = torch.tensor(previous_pos_indices, + dtype=torch.int, + pin_memory=True) self.previous_pos_indices_cuda[0:previous_batch_tokens].copy_( - previous_pos_indices, non_blocking=True) + previous_pos_indices_host, non_blocking=True) self.previous_pos_id_offsets_cuda[ 0:previous_batch_tokens].copy_( new_tokens_lens_device[self.previous_pos_indices_cuda[ 0:previous_batch_tokens]], non_blocking=True) - self.previous_kv_lens_offsets_cuda[0:previous_batchs].copy_( - kv_len_offsets_device[ - self.previous_batch_indices_cuda[:previous_batchs]], - non_blocking=True) + self.previous_kv_lens_offsets_cuda[0:previous_batch_len].copy_( + kv_len_offsets_device[previous_slots], non_blocking=True) # for the requests that do not have previous batch, set the previous_pos_id_offsets and # previous_kv_lens_offsets to zeros to skip the value changes in _preprocess_inputs self.previous_pos_id_offsets_cuda[ previous_batch_tokens:num_requests * (1 + self.max_draft_len)] *= 0 self.previous_kv_lens_offsets_cuda[ - previous_batchs:num_requests] *= 0 + previous_batch_len:num_requests] *= 0 else: # change the data to zeros to skip the value changes in _preprocess_inputs self.previous_pos_id_offsets_cuda *= 0 self.previous_kv_lens_offsets_cuda *= 0 elif new_tokens_device is not None: - previous_batch_tokens = len(previous_batch_indices) - previous_batch_indices = torch.tensor(previous_batch_indices, - dtype=torch.int, - pin_memory=True) - self.previous_batch_indices_cuda[:previous_batch_tokens].copy_( - previous_batch_indices, non_blocking=True) - self.input_ids_cuda[num_tokens:num_tokens + previous_batchs].copy_( - new_tokens_device[ - self.previous_batch_indices_cuda[:previous_batchs]], - non_blocking=True) + seq_slots_device = previous_seq_slots_device() + max_draft_len = max(draft_lens) + new_tokens = new_tokens_device[:max_draft_len + 1, + seq_slots_device, :self.BEAM_WIDTH] + self.input_ids_cuda[num_tokens:num_tokens + + previous_batch_len].copy_(new_tokens.flatten(), + non_blocking=True) position_ids = torch.tensor(position_ids, dtype=torch.int, @@ -1645,7 +1647,6 @@ def _prepare_star_attention_inputs(self, # for star attention, we need customized block ids block_ids_per_seq = [] num_cached_tokens_per_seq = [] - output_token_idx = 0 for request in scheduled_requests.context_requests: request_ids.append(request.py_request_id) prompt_lengths.append(request.py_prompt_len) @@ -1702,8 +1703,6 @@ def _prepare_star_attention_inputs(self, sequence_lengths.append(len(input_id)) block_ids_per_seq.extend([all_cache_indices]) num_cached_tokens_per_seq.append(past_seen_token_num) - request.output_token_idx = output_token_idx - output_token_idx += 1 num_contexts = len(sequence_lengths) for request in scheduled_requests.context_requests: ctx_iter = request.ctx_iters @@ -1743,8 +1742,6 @@ def _prepare_star_attention_inputs(self, sequence_lengths.append(len(input_id)) block_ids_per_seq.extend([all_cache_indices]) num_cached_tokens_per_seq.append(past_seen_token_num) - request.output_token_idx = output_token_idx - 
output_token_idx += 1 num_queries = len(sequence_lengths) - num_contexts # Requests with draft tokens are treated like extend requests. @@ -1802,8 +1799,6 @@ def _prepare_star_attention_inputs(self, position_ids.append(last_query_pos_id + request.gen_iters + 1) block_ids_per_seq.extend([all_cache_indices]) num_cached_tokens_per_seq.append(past_seen_token_num) - request.output_token_idx = output_token_idx - output_token_idx += 1 num_tokens = len(input_ids) assert num_tokens <= self.max_num_tokens, ( @@ -2171,9 +2166,7 @@ def _execute_logit_post_processors(self, num_ctx_req = len(scheduled_requests.context_requests) logits_tensor = outputs["logits"] - for idx, request in enumerate( - itertools.chain(scheduled_requests.context_requests, - scheduled_requests.generation_requests)): + for idx, request in enumerate(scheduled_requests.all_requests()): logits_processors = getattr(request, "py_logits_post_processors", None) if not logits_processors: diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index ed5889a793c..08af489b149 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -11,12 +11,12 @@ import weakref from collections import namedtuple from contextlib import contextmanager -from itertools import chain from typing import Dict, List, Optional, Tuple, Union import torch from tensorrt_llm._torch.pyexecutor.resource_manager import ResourceManagerType +from tensorrt_llm._torch.pyexecutor.seq_slot_manager import SeqSlotManager from tensorrt_llm._utils import (customized_gc_thresholds, global_mpi_rank, is_trace_enabled, nvtx_range, trace_func) from tensorrt_llm.bindings.executor import (DisServingRequestStats, @@ -35,7 +35,7 @@ LlmResponse, executor_request_to_llm_request) from .model_engine import ModelEngine from .sampler import Sampler, SampleState, SampleStateTensors, TorchSampler -from .scheduler import ScheduledRequests +from .scheduler import RequestScheduler, ScheduledRequests # Environment variable to specify iteration ranges for profiling start/stop. # Format: "start1-stop1,start2-stop2,..." or single iterations "iter1,iter2,..." @@ -162,10 +162,11 @@ class PyExecutor: def __init__(self, resource_manager, - scheduler, + scheduler: RequestScheduler, model_engine: ModelEngine, sampler: Sampler, dist: Distributed, + max_num_sequences: int, disable_overlap_scheduler: bool = False, max_input_len: int = 2048, max_batch_size: int = 8, @@ -268,11 +269,13 @@ def __init__(self, if is_trace_enabled("TLLM_TRACE_EXECUTOR_LOOP"): self.event_loop = trace_func(self.event_loop) - if self.draft_model_engine is not None and self.event_loop.__name__ != self._executor_loop.__name__: - raise NotImplementedError( - "Drafting is not supported for selected executor loop. " - "Please disable disagg/pipeline parallelism/overlap scheduler.") - + if self.draft_model_engine is not None: + if self.event_loop.__name__ != self._executor_loop.__name__: + raise NotImplementedError( + "Drafting is not supported for selected executor loop. " + "Please disable disagg/pipeline parallelism/overlap scheduler." 
+ ) + self.draft_seq_slot_manager = SeqSlotManager(max_num_sequences) self.garbage_collection_gen0_threshold = garbage_collection_gen0_threshold self.worker_started = False @@ -754,7 +757,7 @@ def _executor_loop_pp(self): "cpu", non_blocking=True) sample_state = self._sample_async( scheduled_batch, batch_outputs) - sample_state.logits_host = logits_host + sample_state.host.logits = logits_host self._update_request_states(scheduled_batch) if self.enable_iter_perf_stats: @@ -785,7 +788,6 @@ def _executor_loop_pp(self): # Receive tokens from previous pp rank (w.r.t model forward direction) ( logits, - sample_state.log_probs, sample_state.host, ) = self.dist.recv_object( src=self.dist.prev_pp_rank, @@ -793,8 +795,9 @@ def _executor_loop_pp(self): ) if logits is not None: logits_host = torch.from_numpy(logits) - sample_state.logits_host = logits_host - sample_state.logits = logits_host.to(self.device_id) + sample_state.host.logits = logits_host + sample_state.device.logits = logits_host.to( + self.device_id) else: torch.cuda.nvtx.range_push("_handle_new_tokens_last_pp") sample_state.sampler_event.synchronize() @@ -804,16 +807,16 @@ def _executor_loop_pp(self): if not self.dist.is_second_last_pp_rank: if self.send_handles[prev_microbatch_id] is not None: self.send_handles[prev_microbatch_id].Wait() + needs_logits = ( + self._need_return_logits(scheduled_batch) + or (self._need_return_log_probs(scheduled_batch) + and sample_state.host.log_probs is not None)) + serialized_logits = sample_state.host.logits.numpy( + ) if needs_logits else None self.send_handles[ prev_microbatch_id] = self.dist.isend_object( ( - sample_state.logits_host.numpy() if - self._need_return_logits(scheduled_batch) or - (self._need_return_log_probs( - scheduled_batch) - and sample_state.log_probs is not None) - else None, - sample_state.log_probs, + serialized_logits, sample_state.host, ), dest=self.dist.next_pp_rank, @@ -1727,8 +1730,7 @@ def _insert_ngram_iter_stats( total_num_draft_tokens = 0 total_num_accepted_tokens = 0 num_requests_with_draft_tokens = 0 - for request in chain(scheduled_requests.context_requests, - scheduled_requests.generation_requests): + for request in scheduled_requests.all_requests(): num_draft_tokens = 0 if request.py_last_draft_tokens is None else len( request.py_last_draft_tokens) num_accepted_tokens = getattr(request, @@ -1799,38 +1801,33 @@ def _prepare_draft_batch( input_tokens = spec_config.get_draft_model_prompt( request.get_tokens()[beam_idx]) + def create_new_request(input_tokens): + return LlmRequest(request_id=request.py_request_id, + max_new_tokens=request.py_max_new_tokens, + input_tokens=input_tokens, + sampling_config=request.sampling_config, + is_streaming=False, + is_draft=True) + if request.max_beam_num_tokens - 1 == request.py_prompt_len: # This is the first time the draft model is seeing this request. # Prepare a context request. We discard the first token and take # the newly decoded one - this is the convention for EAGLE 2 and 3. 
assert num_draft_tokens == 0 - new_request = LlmRequest( - request_id=request.py_request_id, - max_new_tokens=request.py_max_new_tokens, - input_tokens=input_tokens, - sampling_config=request.sampling_config, - is_streaming=False) - + new_request = create_new_request(input_tokens) draft_batch.context_requests.append(new_request) elif num_accepted_tokens == 0: - new_request = LlmRequest( - request_id=request.py_request_id, - max_new_tokens=request.py_max_new_tokens, - input_tokens=input_tokens[:-1], - sampling_config=request.sampling_config, - is_streaming=False) + new_request = create_new_request(input_tokens[:-1]) # Explicitly add the last token so get_last_tokens() returns # the right value new_request.add_new_token(input_tokens[-1], beam_idx) new_request.state = LlmRequestState.GENERATION_IN_PROGRESS draft_batch.generation_requests.append(new_request) else: - new_request = LlmRequest( - request_id=request.py_request_id, - max_new_tokens=request.py_max_new_tokens, - input_tokens=input_tokens, - sampling_config=request.sampling_config, - is_streaming=False) + new_request = create_new_request(input_tokens) + new_request.context_chunk_size = num_accepted_tokens + 1 + new_request.context_current_position = len( + input_tokens) - num_accepted_tokens - 1 new_request.context_chunk_size = num_accepted_tokens + 1 new_request.context_current_position = len( input_tokens) - num_accepted_tokens - 1 @@ -1849,16 +1846,19 @@ def _prepare_draft_batch( @nvtx_range("_prepare_draft_tokens") def _prepare_draft_tokens(self, scheduled_requests: ScheduledRequests): + if not self.draft_model_engine: + raise ValueError("Draft model engine is not set") + try: draft_batch = self._prepare_draft_batch(scheduled_requests) if draft_batch.batch_size == 0: return + self.draft_seq_slot_manager.prepare_resources(draft_batch) req_id_to_old_request = { req.py_request_id: req - for req in chain(scheduled_requests.context_requests, - scheduled_requests.generation_requests) + for req in scheduled_requests.all_requests() } # Disable cuda graph for the 1st draft model forward @@ -1880,8 +1880,7 @@ def _prepare_draft_tokens(self, scheduled_requests: ScheduledRequests): def _process_decoded_tokens(draft_batch): new_requests = [] - for req in chain(draft_batch.context_requests, - draft_batch.generation_requests): + for req in draft_batch.all_requests(): target_model_req = req_id_to_old_request[req.py_request_id] target_model_req.py_draft_tokens.append( req.get_last_tokens(0)) @@ -1889,6 +1888,8 @@ def _process_decoded_tokens(draft_batch): target_model_req.py_draft_tokens ) < target_model_req.py_draft_pages_allocated: new_requests.append(req) + else: + self.draft_seq_slot_manager.free_resources(req) return new_requests @@ -2126,14 +2127,12 @@ def _pause_requests(self, requests_to_pause): def _add_inflight_ids(self, scheduled_requests): """Add reqids of current requests to self.inflight_req_ids.""" - for req in chain(scheduled_requests.context_requests, - scheduled_requests.generation_requests): + for req in scheduled_requests.all_requests(): self.inflight_req_ids.insert(req.request_id) def _remove_inflight_ids(self, scheduled_requests): """Remove reqids of current requests from self.inflight_req_ids.""" - for req in chain(scheduled_requests.context_requests, - scheduled_requests.generation_requests): + for req in scheduled_requests.all_requests(): self.inflight_req_ids.erase(req.request_id) def _should_exclude_last_generation_logits(self) -> bool: diff --git a/tensorrt_llm/_torch/pyexecutor/sampler.py 
b/tensorrt_llm/_torch/pyexecutor/sampler.py index c96ad3356f9..885dd0c47b6 100644 --- a/tensorrt_llm/_torch/pyexecutor/sampler.py +++ b/tensorrt_llm/_torch/pyexecutor/sampler.py @@ -28,9 +28,11 @@ from .scheduler import ScheduledRequests -@dataclass(frozen=True, kw_only=True) +@dataclass(kw_only=True) class SampleStateTensors: new_tokens: torch.Tensor + logits: torch.Tensor | None = None + log_probs: torch.Tensor | None = None def values(self): return vars(self).values() @@ -40,13 +42,6 @@ def values(self): class SampleState: scheduled_requests: ScheduledRequests - logits: torch.Tensor = None - logits_host: torch.Tensor = None - - # Set when decode_async() has evaluated these to avoid computing again in update_requests() - # log_probs[request_idx][token_idx] - log_probs: list[list[float] | None] | None = None - device: SampleStateTensors = None host: SampleStateTensors = None @@ -78,10 +73,12 @@ class EarlyStopSampler(Sampler): def sample_async(self, scheduled_requests: ScheduledRequests, model_outputs) -> SampleState: - return SampleState(scheduled_requests=scheduled_requests, - logits=model_outputs['logits']) + host = SampleStateTensors(logits=model_outputs['logits'], + new_tokens=torch.empty(0)) + return SampleState(scheduled_requests=scheduled_requests, host=host) def update_requests(self, state: SampleState) -> None: + assert isinstance(state, SampleState) scheduled_requests = state.scheduled_requests assert (not scheduled_requests.generation_requests) for idx, request in enumerate(scheduled_requests.context_requests): @@ -89,7 +86,7 @@ def update_requests(self, state: SampleState) -> None: # NOTE: This is a hack: set finish reason manually and set the beam 0 request.set_finished_reason(FinishReason.LENGTH, 0) if request.py_return_context_logits: - logits = state.logits[idx] + logits = state.host.logits[idx] if logits.ndim == 1: # For BERT: Add axis to be compatible with LogitsStorage # (LogitsStorage will interpret this dim as the prompt_len which @@ -105,8 +102,6 @@ def top_k_sampling_batch(logits, top_k=50): # logits should be 2D :[batch_size, vocab_size] batch_size, vocab_size = logits.size() - raw_probs = torch.softmax(logits, dim=-1) - # get first top_k logits of each sample and their indices values, indices = torch.topk(logits, top_k, dim=-1) min_values = values[:, -1].unsqueeze(-1).expand(batch_size, vocab_size) @@ -116,24 +111,18 @@ def top_k_sampling_batch(logits, top_k=50): torch.full_like(logits, float('-inf')), logits) # compute probability distribution - probs = torch.softmax(logits, dim=-1) + softmax = torch.softmax(logits, dim=-1) # sample from the distribution and generate result of [batch_size, 1] - next_tokens = torch.multinomial(probs, num_samples=1).squeeze(-1) - token_probs = torch.gather(raw_probs, dim=1, - index=next_tokens.unsqueeze(1)).squeeze(-1) - log_probs = torch.log(token_probs) - return next_tokens, log_probs + next_tokens = torch.multinomial(softmax, num_samples=1).squeeze(-1) + return next_tokens, softmax -def top_p_sampling_batch(logits, top_p=0.9): +def top_p_sampling_batch(logits: torch.Tensor, top_p: float = 0.9): logits_dim = logits.dim() if logits_dim == 1: logits = logits.unsqueeze(0) - # logits should be 2D :[batch_size, vocab_size] - batch_size, vocab_size = logits.size() - - raw_probs = torch.softmax(logits, dim=-1) + assert logits_dim == 2, "logits should be 2D: [batch_size, vocab_size]" # sort the logits of each sample in descending order sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1) @@ -153,46 +142,82 @@ def 
top_p_sampling_batch(logits, top_p=0.9): logits = logits.masked_fill(indices_to_remove, float('-inf')) # compute probability distribution - probs = torch.softmax(logits, dim=-1) + softmax = torch.softmax(logits, dim=-1) # sample from the distribution and generate result of [batch_size, 1] - next_tokens = torch.multinomial(probs, num_samples=1).squeeze(-1) - token_probs = torch.gather(raw_probs, dim=1, - index=next_tokens.unsqueeze(1)).squeeze(-1) - log_probs = torch.log(token_probs) - return next_tokens, log_probs + next_tokens = torch.multinomial(softmax, num_samples=1).squeeze(-1) + return next_tokens, softmax def greedy_search_sampling_batch(logits): - raw_probs = torch.softmax(logits, dim=-1) next_tokens = torch.argmax(logits, dim=-1) - token_probs = torch.gather(raw_probs, dim=1, - index=next_tokens.unsqueeze(1)).squeeze(-1) - log_probs = torch.log(token_probs) - return next_tokens, log_probs + softmax = torch.softmax(logits, dim=-1) + return next_tokens, softmax -def decode_single_request(request: LlmRequest, logits): +def sample_single_request(request: LlmRequest, logits: torch.Tensor): assert logits.dim( ) == 2 and logits.shape[0] == 1, "logits should have shape [1, vocab_size]" if request.sampling_config.top_p is not None and len( request.sampling_config.top_p) > 0: - next_tokens, log_probs = top_p_sampling_batch( - logits, request.sampling_config.top_p[0]) + return top_p_sampling_batch(logits, request.sampling_config.top_p[0]) elif request.sampling_config.top_k is not None and len( request.sampling_config.top_k) > 0: - next_tokens, log_probs = top_k_sampling_batch( - logits, request.sampling_config.top_k[0]) + return top_k_sampling_batch(logits, request.sampling_config.top_k[0]) else: - next_tokens, log_probs = greedy_search_sampling_batch(logits) - return next_tokens, log_probs + return greedy_search_sampling_batch(logits) -class TorchSampler(Sampler): +def new_tokens_slice(request: LlmRequest, beam: int, *, + size: int) -> tuple[slice, int, int]: + return slice(0, size), request.seq_slot, beam + - def __init__(self, max_seq_len: int, mixed_sampler: bool = False): - self.max_seq_len = max_seq_len - self.mixed_sampler = mixed_sampler +def add_token(request: LlmRequest, + new_tokens: torch.Tensor, + *, + beam: int, + step: int = 0) -> int: + seq_slot = request.seq_slot + assert seq_slot is not None + new_token = int(new_tokens[step, request.seq_slot, beam]) + request.add_new_token(new_token, beam) + return new_token + + +class TorchSampler(Sampler): + BEAM = 0 + MAX_BEAM_WIDTH = BEAM + 1 + + @dataclass(frozen=True, kw_only=True) + class Store: + new_tokens: torch.Tensor + """Shape: See cpp DecoderState.getAllNewTokens()""" + + @dataclass(frozen=True, kw_only=True) + class Args: + max_seq_len: int + max_draft_tokens: int + max_num_sequences: int + max_beam_width: int + mixed_sampler: bool + + def __init__(self, args: Args): + self.max_seq_len = args.max_seq_len + self.mixed_sampler = args.mixed_sampler + self.max_tokens = args.max_draft_tokens + 1 + assert args.max_beam_width == self.MAX_BEAM_WIDTH, "TorchSampler only supports beam_width = 1" + self.num_seq_slots = args.max_num_sequences + + # AutoDeploy build creates the sampler in inference mode, + # which would disallow in-place mutating of new_tokens. + # So, we temporarily exit inference mode. 
+ with torch.inference_mode(False): + new_tokens = torch.zeros( + (self.max_tokens, self.num_seq_slots, self.MAX_BEAM_WIDTH), + dtype=torch.int, + device='cuda') + self.store = self.Store(new_tokens=new_tokens) def _meet_max_token_stop_criteria(self, request: LlmRequest, num_tokens: int): @@ -200,7 +225,8 @@ def _meet_max_token_stop_criteria(self, request: LlmRequest, >= request.py_max_new_tokens) or (num_tokens >= self.max_seq_len) - def _meet_stop_token_criteria(self, request: LlmRequest): + @staticmethod + def _meet_stop_token_criteria(request: LlmRequest): if request.py_stop_words_list: assert isinstance( request.py_stop_words_list, @@ -218,233 +244,163 @@ def _meet_stop_token_criteria(self, request: LlmRequest): return True return False - def _handle_stop_criteria(self, request: LlmRequest, new_token: int, - num_tokens: int, beam_idx: int) -> bool: + def _handle_stop_criteria(self, request: LlmRequest, new_token: int, *, + beam: int) -> bool: """Handle stop criteria and set appropriate finish reasons and state. Returns True if generation should stop.""" if new_token == request.py_end_id: - request.state = LlmRequestState.GENERATION_COMPLETE - request.set_finished_reason(FinishReason.END_ID, beam_idx) + request.finish_by_reason(FinishReason.END_ID) return True + num_tokens = request.get_num_tokens(beam) if self._meet_max_token_stop_criteria(request, num_tokens): - request.state = LlmRequestState.GENERATION_COMPLETE - request.set_finished_reason(FinishReason.LENGTH, beam_idx) + request.finish_by_reason(FinishReason.LENGTH) return True if self._meet_stop_token_criteria(request): - request.state = LlmRequestState.GENERATION_COMPLETE - request.set_finished_reason(FinishReason.STOP_WORDS, beam_idx) + request.finish_by_reason(FinishReason.STOP_WORDS) return True return False - def update_requests(self, state: SampleState) -> None: - if state.sampler_event: - state.sampler_event.synchronize() - new_tokens_list = state.host.new_tokens.tolist() - scheduled_requests = state.scheduled_requests - - request_idx = 0 - token_idx = 0 - beam_idx = 0 - - def advance_idx(num_tokens=1): - nonlocal request_idx, token_idx - request_idx += 1 - token_idx += num_tokens - - def handle_logits(request: LlmRequest, tokens: list[int], count=1): - if state.logits is None: - return - if not request.py_return_generation_logits and not request.py_return_log_probs: - return - - current_slice = slice(token_idx, token_idx + count) - current_logits = state.logits[current_slice] - + def handle_logits(self, request: LlmRequest, state: SampleState, *, + beam: int, count: int): + current_slice = new_tokens_slice(request, beam, size=count) + if request.py_return_generation_logits: + assert state.host.logits is not None + current_logits = state.host.logits[current_slice] request.py_result.append_generation_logits(current_logits) - - if not request.py_return_log_probs: - return - - if state.log_probs: - log_probs = state.log_probs[request_idx] - else: - _, log_probs = greedy_search_sampling_batch(current_logits) + if request.py_return_log_probs: + assert state.host.log_probs is not None + log_probs = state.host.log_probs[request.seq_slot][beam][:count] + current_tokens = state.host.new_tokens[current_slice] token_log_probs = [{ - token: Logprob(logprob=logprob, rank=1) - } for token, logprob in zip(tokens, log_probs.tolist())] + int(token): Logprob(logprob=logprob, rank=1) + } for token, logprob in zip(current_tokens, log_probs.tolist())] + assert beam == 0, "The following call relies on beam_width to be 1 - hence the list with a 
single element" request.py_result.append_log_probs([token_log_probs]) - if hasattr(scheduled_requests, 'chunked_requests'): - request_idx += len(scheduled_requests.chunked_requests) - token_idx += len(scheduled_requests.chunked_requests) - - for request in scheduled_requests.context_requests: - if request.context_remaining_length != 0: - advance_idx() - continue - - if request.state != LlmRequestState.GENERATION_COMPLETE: - new_token = new_tokens_list[token_idx] - num_tokens = request.add_new_token(new_token, beam_idx) - self._handle_stop_criteria(request, new_token, num_tokens, - beam_idx) - handle_logits(request, [new_token]) - request.py_decoding_iter += 1 - advance_idx() - - extend_requests = [] - generation_requests = [] - for request in scheduled_requests.generation_requests: - if len(request.py_draft_tokens) > 0: - extend_requests.append(request) - else: - generation_requests.append(request) + def process_draft_tokens(self, request: LlmRequest, + new_tokens: torch.Tensor, new_token: int) -> int: + num_accepted = 0 + for draft_token in request.py_draft_tokens: + if draft_token != new_token: + # Reject. + break + num_accepted += 1 + new_token = add_token(request, + new_tokens, + beam=self.BEAM, + step=num_accepted) + if self._handle_stop_criteria(request, new_token, beam=self.BEAM): + break + return num_accepted - for request in extend_requests: - if request.state != LlmRequestState.GENERATION_COMPLETE: - new_token = new_tokens_list[token_idx] - num_tokens = request.add_new_token(new_token, beam_idx) - if self._handle_stop_criteria(request, new_token, num_tokens, - beam_idx): - continue + def update_requests(self, state: SampleState) -> None: + assert isinstance(state, SampleState) + if state.sampler_event: + state.sampler_event.synchronize() + new_tokens = state.host.new_tokens - # Accept draft tokens (if we have any) if and only if they match the new - # token exactly. - num_accepted = 0 - new_tokens = [new_token] - for draft_token in request.py_draft_tokens: - if draft_token != new_token: - # Reject. 
- break - num_accepted += 1 - new_token = new_tokens_list[token_idx + num_accepted] - num_tokens = request.add_new_token(new_token, beam_idx) - new_tokens.append(num_tokens) # `num_tokens`->`new_token` - - if self._handle_stop_criteria(request, new_token, - num_tokens, beam_idx): - break - handle_logits(request, new_tokens, num_accepted) - request.py_decoding_iter += 1 - request.py_num_accepted_draft_tokens = num_accepted - request.py_rewind_len = request.py_draft_pages_allocated - num_accepted - advance_idx(len(request.py_draft_tokens) + 1) + for req in state.scheduled_requests.context_requests: + if req.state == LlmRequestState.GENERATION_COMPLETE or req.context_remaining_length != 0: + continue + new_token = add_token(req, new_tokens, beam=self.BEAM) + stop = self._handle_stop_criteria(req, new_token, beam=self.BEAM) + self.handle_logits(req, state, beam=self.BEAM, count=1) + req.py_decoding_iter += 1 - for request in generation_requests: - if request.state != LlmRequestState.GENERATION_COMPLETE: - new_token = new_tokens_list[token_idx] - num_tokens = request.add_new_token(new_token, beam_idx) - self._handle_stop_criteria(request, new_token, num_tokens, - beam_idx) - handle_logits(request, [new_token]) - request.py_decoding_iter += 1 - advance_idx() - - def _mixed_sample(self, scheduled_requests: ScheduledRequests, - model_outputs) -> SampleState: - logits = model_outputs["logits"] - log_probs = [] - new_tokens_device_array = [] - - idx = 0 - - for request in scheduled_requests.context_requests: - assert not request.py_return_context_logits, "Return context logits not supported" - token_logits = logits[idx:idx + 1, :] - new_token, probs = decode_single_request(request, token_logits) - new_tokens_device_array.append(new_token) - probs = [probs.tolist()] if request.py_return_log_probs else None - log_probs.append(probs) # Currently always beam_width=1 - idx += 1 - - for request in scheduled_requests.generation_requests: - if request.state == LlmRequestState.GENERATION_COMPLETE: + for req in state.scheduled_requests.generation_requests: + if req.state == LlmRequestState.GENERATION_COMPLETE: continue - assert len( - request.py_draft_tokens - ) == 0, "Speculative decoding not supported in SeparateDecoder." 
- token_logits = logits[idx:idx + 1, :] - new_token, probs = decode_single_request(request, token_logits) - new_tokens_device_array.append(new_token) - probs = [probs.tolist()] if request.py_return_log_probs else None - log_probs.append(probs) # Currently always beam_width=1 - idx += 1 - - new_tokens_device = torch.cat(new_tokens_device_array) - new_tokens_host = new_tokens_device.to('cpu', non_blocking=True) - sampler_event = torch.cuda.Event() - sampler_event.record() + new_token = add_token(req, new_tokens, beam=self.BEAM) + stop = self._handle_stop_criteria(req, new_token, beam=self.BEAM) + processed = 1 + if not stop and len(req.py_draft_tokens) > 0: + num_accepted = self.process_draft_tokens( + req, new_tokens, new_token) + req.py_num_accepted_draft_tokens = num_accepted + req.py_rewind_len = req.py_draft_pages_allocated - num_accepted + processed += num_accepted + self.handle_logits(req, state, beam=self.BEAM, count=processed) + req.py_decoding_iter += 1 + + def log_probs_host(self, requests: Iterable[LlmRequest]): + """Shape: In lockstep with TRTLLMSampler: https://github.com/NVIDIA/TensorRT-LLM/blob/cea5dd1e3883b18bf50901a7f196f50a9544c28c/cpp/include/tensorrt_llm/runtime/decoderState.h#L103""" + if any(req.py_return_log_probs for req in requests): + return torch.empty( + (self.num_seq_slots, self.MAX_BEAM_WIDTH, self.max_tokens), + device="cpu", + pin_memory=True) + return None + + def gen_logits_host(self, requests: Iterable[LlmRequest], vocab_size: int): + if any(req.py_return_generation_logits for req in requests): + return torch.empty((self.max_tokens, self.num_seq_slots, + self.MAX_BEAM_WIDTH, vocab_size), + device="cpu", + pin_memory=True) + return None - return SampleState( - scheduled_requests=scheduled_requests, - logits=logits, - device=SampleStateTensors(new_tokens=new_tokens_device), - host=SampleStateTensors(new_tokens=new_tokens_host), - sampler_event=sampler_event, - log_probs=log_probs) - - def _batch_sample(self, scheduled_requests: ScheduledRequests, - model_outputs) -> SampleState: - logits = model_outputs["logits"] - new_tokens_device = torch.argmax(logits, dim=-1) - new_tokens_host = new_tokens_device.to('cpu', non_blocking=True) + def sample_async(self, scheduled_requests: ScheduledRequests, + model_outputs: dict[str, torch.Tensor]) -> SampleState: + requests = scheduled_requests.all_requests() + new_tokens = self.store.new_tokens + vocab_size = model_outputs["logits"].shape[-1] + log_probs_host = self.log_probs_host(requests) + gen_logits_host = self.gen_logits_host(requests, vocab_size) + self._process_requests(requests, + model_outputs, + new_tokens, + gen_logits_host=gen_logits_host, + log_probs_host=log_probs_host) + new_tokens_host = new_tokens.to(device="cpu", non_blocking=True) sampler_event = torch.cuda.Event() sampler_event.record() - return SampleState( - scheduled_requests=scheduled_requests, - logits=logits, - device=SampleStateTensors(new_tokens=new_tokens_device), - host=SampleStateTensors(new_tokens=new_tokens_host), - sampler_event=sampler_event) - - def sample_async(self, scheduled_requests: ScheduledRequests, - model_outputs) -> SampleState: - if self.mixed_sampler: - return self._mixed_sample(scheduled_requests, model_outputs) - else: - return self._batch_sample(scheduled_requests, model_outputs) - - -class TorchStarAttentionSampler(TorchSampler): - - def update_one_request(self, request: LlmRequest, - new_tokens_list: list[int], logits: torch.Tensor): - beam_idx = 0 - - output_token_idx = request.output_token_idx - new_token = 
new_tokens_list[output_token_idx] - num_tokens = request.add_new_token(new_token, beam_idx) - - current_logits = logits[output_token_idx].unsqueeze(0) - if request.py_return_generation_logits: - request.py_result.append_generation_logits(current_logits) - if request.py_return_log_probs: - _, log_probs = greedy_search_sampling_batch(current_logits) - request.py_result.append_log_probs([[{ - new_token: - Logprob(logprob=log_probs.item(), rank=1) - }]]) - - self._handle_stop_criteria(request, new_token, num_tokens, beam_idx) - if request.state != LlmRequestState.GENERATION_COMPLETE: - request.py_decoding_iter += 1 - - def update_requests(self, state: SampleState): - if state.sampler_event: - state.sampler_event.synchronize() - new_tokens_list = state.host.new_tokens.tolist() - logits = state.logits - - for request in state.scheduled_requests.context_requests: - if request.state == LlmRequestState.GENERATION_IN_PROGRESS: - self.update_one_request(request, new_tokens_list, logits) - - for request in state.scheduled_requests.generation_requests: - self.update_one_request(request, new_tokens_list, logits) + return SampleState(scheduled_requests=scheduled_requests, + device=SampleStateTensors(new_tokens=new_tokens), + host=SampleStateTensors(new_tokens=new_tokens_host, + log_probs=log_probs_host, + logits=gen_logits_host), + sampler_event=sampler_event) + + def _process_requests(self, + requests: list[LlmRequest], + model_outputs: dict[str, torch.Tensor], + new_tokens: torch.Tensor, + *, + gen_logits_host: torch.Tensor | None = None, + log_probs_host: torch.Tensor | None = None): + beam = self.BEAM + offset = 0 + raw_logits = model_outputs["logits"] + + for request in requests: + steps = 1 + if len(request.py_draft_tokens) > 0: + assert not self.mixed_sampler, "Speculative decoding not supported in mixed sampler" + steps += len(request.py_draft_tokens) + logits = raw_logits[offset:offset + steps] + if self.mixed_sampler: + next_tokens, softmax = sample_single_request(request, logits) + else: + next_tokens, softmax = greedy_search_sampling_batch(logits) + current_slice = new_tokens_slice(request, beam, size=steps) + new_tokens[current_slice] = next_tokens + if "d2t" in model_outputs: # Eagle3 + new_tokens[current_slice] += model_outputs["d2t"][ + new_tokens[current_slice]] + if gen_logits_host is not None: + gen_logits_host[current_slice].copy_(logits, non_blocking=True) + if log_probs_host is not None: + assert beam == 0, "The following call relies on beam_width to be 1 - hence the unsqueeze" + token_probs = torch.gather( + softmax, dim=1, index=next_tokens.unsqueeze(1)).squeeze(-1) + log_probs = torch.log(token_probs) + log_probs_host[request.seq_slot, + beam, :steps].copy_(log_probs, non_blocking=True) + offset += steps class Algorithms: @@ -457,19 +413,17 @@ def __repr__(self): return f"Algs({', '.join(algs)})" -@dataclass(frozen=True, kw_only=True) +@dataclass(kw_only=True) class SampleStateTensorsHostTRTLLM(SampleStateTensors): finished_sum: torch.Tensor finish_reasons: torch.Tensor sequence_lengths: torch.Tensor - log_probs: torch.Tensor - cum_log_probs: torch.Tensor + cum_log_probs: torch.Tensor | None = None @dataclass(kw_only=True) class SampleStateTRTLLM(SampleState): host: SampleStateTensorsHostTRTLLM - device: SampleStateTensors class TRTLLMSampler(Sampler): @@ -532,13 +486,6 @@ def _initialize_store(self): DecoderInputBuffers(self.max_num_sequences, self.executor_config.max_batch_size, self.MAX_DECODING_TOKENS, buffer_manager), - "new_tokens_device_tensor": - torch.empty(( - 
self.executor_config.max_batch_size, - self.executor_config.max_beam_width, - ), - dtype=torch.int, - device='cuda'), "sequence_lengths_host": torch.empty(( self.executor_config.max_batch_size, @@ -605,7 +552,6 @@ def beam_width(scheduled_requests: Iterable[LlmRequest]) -> int: def sample_async(self, scheduled_requests: ScheduledRequests, model_outputs) -> SampleStateTRTLLM: batch_size = scheduled_requests.batch_size - beam_width = self.beam_width(scheduled_requests.all_requests) self.setup_sampler_step(scheduled_requests.context_requests) @@ -635,20 +581,6 @@ def sample_async(self, scheduled_requests: ScheduledRequests, self.algs.decoder.forward_async(self.store["decoder_state"], decoding_input) - # NOTE: The following code prepares a new_tokens_device_tensor in accordance with the - # current implementation of model_engine. - # TODO: When we support speculative decoding: - # new_tokens_device_tensor should be, for speculative decoding cases: [batch, 1 + draft_len], others: [batch] - new_tokens_device_tensor = self.store[ - "new_tokens_device_tensor"][:batch_size, :beam_width] - seq_slots = [ - request.seq_slot for request in scheduled_requests.all_requests - ] - new_tokens_device_tensor.copy_( - self.store["decoder_state"].all_new_tokens[0][seq_slots], - non_blocking=True) - new_tokens_device_tensor = new_tokens_device_tensor.view(-1) - new_output_tokens = self.store["decoder_state"].all_new_tokens.to( 'cpu', non_blocking=True) finished_sum = self.store["decoder_state"].finished_sum.to( @@ -658,16 +590,17 @@ def sample_async(self, scheduled_requests: ScheduledRequests, sequence_lengths = self.store["decoder_state"].sequence_lengths.to( 'cpu', non_blocking=True) - log_probs = torch.empty([0], dtype=torch.float, device='cpu') - cum_log_probs = torch.empty([0], dtype=torch.float, device='cpu') + log_probs = None + cum_log_probs = None if any(request.py_return_log_probs - for request in scheduled_requests.all_requests): + for request in scheduled_requests.all_requests()): log_probs = self.store["decoder_state"].log_probs.to( 'cpu', non_blocking=True) cum_log_probs = self.store["decoder_state"].cum_log_probs.to( 'cpu', non_blocking=True) - device = SampleStateTensors(new_tokens=new_tokens_device_tensor) + device = SampleStateTensors( + new_tokens=self.store["decoder_state"].all_new_tokens) host = SampleStateTensorsHostTRTLLM(new_tokens=new_output_tokens, finished_sum=finished_sum, @@ -680,7 +613,6 @@ def sample_async(self, scheduled_requests: ScheduledRequests, sampler_event.record() return SampleStateTRTLLM(scheduled_requests=scheduled_requests, - logits=model_outputs["logits"], device=device, host=host, sampler_event=sampler_event) @@ -690,7 +622,8 @@ def update_requests(self, state: SampleStateTRTLLM): scheduled_requests = state.scheduled_requests assert scheduled_requests.batch_size > 0 - beam_width = self.beam_width(scheduled_requests.all_requests) + requests = scheduled_requests.all_requests() + beam_width = self.beam_width(requests) sampler_event = state.sampler_event if sampler_event: @@ -701,7 +634,7 @@ def update_requests(self, state: SampleStateTRTLLM): finish_reasons_host = state.host.finish_reasons sequence_lengths_host_data = state.host.sequence_lengths - for request in scheduled_requests.all_requests: + for request in requests: if request.is_context_init_state: continue @@ -722,17 +655,20 @@ def update_requests(self, state: SampleStateTRTLLM): seq_len - request.get_num_tokens(beam)) for step in range(num_new_tokens[beam]): - new_token = new_tokens_host[step][seq_slot][beam] - 
request.add_new_token(new_token, beam) + new_token = add_token(request, + new_tokens_host, + beam=beam, + step=step) if request.py_return_log_probs: + assert state.host.log_probs is not None # NOTE: Log probs with drafting has not been tested yet. begin_log_probs_offset = request.prompt_len if request.sampling_config.beam_width == 1 else 0 current_token = seq_len - request.prompt_len - num_new_tokens[ beam] + step log_probs.append({ - new_token.item(): + new_token: Logprob(logprob=state.host.log_probs[seq_slot][beam] [begin_log_probs_offset + current_token].item(), diff --git a/tensorrt_llm/_torch/pyexecutor/scheduler.py b/tensorrt_llm/_torch/pyexecutor/scheduler.py index 9ce25061427..26df44874a0 100644 --- a/tensorrt_llm/_torch/pyexecutor/scheduler.py +++ b/tensorrt_llm/_torch/pyexecutor/scheduler.py @@ -1,6 +1,5 @@ from abc import ABC, abstractmethod from collections import namedtuple -from itertools import chain from typing import Optional from tensorrt_llm.bindings import executor as tb_executor @@ -36,9 +35,8 @@ def can_run_cuda_graph(self) -> bool: def batch_size(self) -> int: return len(self.context_requests) + len(self.generation_requests) - @property - def all_requests(self) -> chain[LlmRequest]: - return chain(self.context_requests, self.generation_requests) + def all_requests(self) -> list[LlmRequest]: + return self.context_requests + self.generation_requests class RequestScheduler(ABC): diff --git a/tensorrt_llm/_torch/pyexecutor/seq_slot_manager.py b/tensorrt_llm/_torch/pyexecutor/seq_slot_manager.py index 523a9693326..2dfe1737467 100644 --- a/tensorrt_llm/_torch/pyexecutor/seq_slot_manager.py +++ b/tensorrt_llm/_torch/pyexecutor/seq_slot_manager.py @@ -1,5 +1,3 @@ -import itertools - from .llm_request import LlmRequest from .resource_manager import BaseResourceManager, SlotManager from .scheduler import ScheduledRequests @@ -17,10 +15,8 @@ def get_needed_resource_to_completion(self, request: LlmRequest) -> int: return 1 def prepare_resources(self, scheduled_batch: ScheduledRequests) -> None: - for llm_req in itertools.chain(scheduled_batch.context_requests, - scheduled_batch.generation_requests): - if (llm_req.is_context_init_state and llm_req.seq_slot is None) or \ - llm_req.is_disagg_generation_transmission_complete: + for llm_req in scheduled_batch.all_requests(): + if llm_req.seq_slot is None or llm_req.is_disagg_generation_transmission_complete: llm_req.seq_slot = self.slot_manager.add_slot( llm_req.request_id) diff --git a/tensorrt_llm/_torch/speculative/eagle3.py b/tensorrt_llm/_torch/speculative/eagle3.py index e6183cc1528..3ed84781036 100644 --- a/tensorrt_llm/_torch/speculative/eagle3.py +++ b/tensorrt_llm/_torch/speculative/eagle3.py @@ -10,7 +10,7 @@ from ..attention_backend import AttentionMetadata from ..pyexecutor.llm_request import LlmRequest from ..pyexecutor.resource_manager import BaseResourceManager, SlotManager -from ..pyexecutor.sampler import SampleState, SampleStateTensors, TorchSampler +from ..pyexecutor.sampler import TorchSampler from ..pyexecutor.scheduler import ScheduledRequests from .interface import SpecConfig, SpecMetadata, SpeculativeDecodingMode from .mtp import MTPSampler @@ -214,26 +214,6 @@ def get_hidden_states(self): return hidden_states -class Eagle3Sampler(TorchSampler): - - def _batch_sample(self, scheduled_requests, model_outputs) -> SampleState: - logits = model_outputs["logits"] - new_tokens_device = torch.argmax(logits, dim=-1) - if "d2t" in model_outputs: - d2t = model_outputs["d2t"] - new_tokens_device = d2t[new_tokens_device] 
+ new_tokens_device - device = SampleStateTensors(new_tokens=new_tokens_device) - host = SampleStateTensors( - new_tokens=new_tokens_device.to('cpu', non_blocking=True)) - sampler_event = torch.cuda.Event() - sampler_event.record() - return SampleState(scheduled_requests=scheduled_requests, - logits=logits, - device=device, - host=host, - sampler_event=sampler_event) - - @dataclass class Eagle3OneModelSpecMetadata(SpecMetadata): # The hidden states @@ -299,31 +279,10 @@ def maybe_capture_hidden_states( break -class Eagle3Decoder(TorchSampler): - - def _batch_sample(self, scheduled_requests, model_outputs) -> SampleState: - logits = model_outputs["logits"] - new_tokens_device = torch.argmax(logits, dim=-1) - if "d2t" in model_outputs: - d2t = model_outputs["d2t"] - new_tokens_device = d2t[new_tokens_device] + new_tokens_device - new_tokens_host = new_tokens_device.to('cpu', non_blocking=True) - new_tensors_device = {"new_tokens_device": new_tokens_device} - new_tensors_host = {"new_tokens_host": new_tokens_host} - decoder_event = torch.cuda.Event() - decoder_event.record() - return SampleState(scheduled_requests=scheduled_requests, - logits=logits, - new_tensors_device=new_tensors_device, - new_tensors_host=new_tensors_host, - decoder_event=decoder_event) - - -class Eagle3OneModelDecoder(MTPSampler): +class Eagle3OneModelSampler(MTPSampler): - def __init__(self, max_seq_len: int, config: Eagle3Config): - super().__init__(max_seq_len, None) - self.draft_len = config.max_draft_tokens + def __init__(self, args: TorchSampler.Args): + super().__init__(args, nextn=args.max_draft_tokens) class Eagle3OneModelWorker(nn.Module): diff --git a/tensorrt_llm/_torch/speculative/mtp.py b/tensorrt_llm/_torch/speculative/mtp.py index 25edbdae363..f5e432690ee 100644 --- a/tensorrt_llm/_torch/speculative/mtp.py +++ b/tensorrt_llm/_torch/speculative/mtp.py @@ -14,7 +14,7 @@ from .interface import SpecConfig, SpecMetadata, SpeculativeDecodingMode -@dataclass(frozen=True, kw_only=True) +@dataclass(kw_only=True) class SampleStateTensorsMTP(SampleStateTensors): new_tokens_lens: torch.Tensor next_draft_tokens: torch.Tensor @@ -248,12 +248,10 @@ class MTPSampler(TorchSampler): SampleState = SampleStateMTP - def __init__(self, max_seq_len: int, config: MTPConfig): - super().__init__(max_seq_len, False) + def __init__(self, args: TorchSampler.Args, *, nextn: int): + super().__init__(args) self.mapping = None - self.draft_len = 0 - if config is not None: - self.draft_len = config.num_nextn_predict_layers + self.draft_len = nextn def _draft_meet_max_token_stop_criteria(self, request: LlmRequest, num_tokens: int, beam_idx: int): @@ -283,8 +281,9 @@ def update_requests(self, state: SampleStateMTP) -> None: if request.state != LlmRequestState.GENERATION_COMPLETE: new_token = new_tokens_list[idx][0] num_tokens = request.add_new_token(new_token, beam_idx) - should_stop = self._handle_stop_criteria( - request, new_token, num_tokens, beam_idx) + should_stop = self._handle_stop_criteria(request, + new_token, + beam=beam_idx) if self._draft_meet_max_token_stop_criteria( request, num_tokens, beam_idx): should_stop = True @@ -303,8 +302,9 @@ def update_requests(self, state: SampleStateMTP) -> None: for i in range(num_new_tokens): new_token = new_tokens[i] num_tokens = request.add_new_token(new_token, beam_idx) - should_stop = self._handle_stop_criteria( - request, new_token, num_tokens, beam_idx) + should_stop = self._handle_stop_criteria(request, + new_token, + beam=beam_idx) if should_stop: break if 
self._draft_meet_max_token_stop_criteria( @@ -344,7 +344,6 @@ def sample_async(self, scheduled_requests: ScheduledRequests, for request in scheduled_requests.context_requests: request.py_draft_tokens = [1] * self.draft_len return SampleStateMTP(scheduled_requests=scheduled_requests, - logits=model_outputs['logits'], device=device, host=host, sampler_event=sampler_event) diff --git a/tensorrt_llm/_torch/speculative/utils.py b/tensorrt_llm/_torch/speculative/utils.py index 3dd49bb108f..85b2bf46e9c 100644 --- a/tensorrt_llm/_torch/speculative/utils.py +++ b/tensorrt_llm/_torch/speculative/utils.py @@ -1,6 +1,9 @@ +from tensorrt_llm._torch.pyexecutor.sampler import TorchSampler +from tensorrt_llm._torch.speculative.interface import SpecConfig + from .draft_target import DraftTargetSpecMetadata -from .eagle3 import (Eagle3OneModelDecoder, Eagle3OneModelSpecMetadata, - Eagle3OneModelWorker, Eagle3ResourceManager, Eagle3Sampler, +from .eagle3 import (Eagle3OneModelSampler, Eagle3OneModelSpecMetadata, + Eagle3OneModelWorker, Eagle3ResourceManager, Eagle3SpecMetadata) from .mtp import (MTPEagleWorker, MTPHiddenStatesManager, MTPSampler, MTPSpecMetadata, MTPWorker) @@ -77,15 +80,17 @@ def get_spec_resource_manager(spec_config, return None -def get_spec_decoder(max_seq_len, spec_config): +def get_spec_decoder(sampler_args: TorchSampler.Args, spec_config: SpecConfig): if spec_config.spec_dec_mode.is_mtp(): - return MTPSampler(max_seq_len, spec_config) - elif spec_config.spec_dec_mode.is_eagle3(): - return Eagle3Sampler(max_seq_len) - elif spec_config.spec_dec_mode.is_eagle3_one_model(): - return Eagle3OneModelDecoder(max_seq_len, spec_config) - else: - return None + return MTPSampler(sampler_args, + nextn=spec_config.num_nextn_predict_layers) + if spec_config.spec_dec_mode.is_eagle3(): + # TorchSampler handles Eagle3 gracefully, by integrating d2t into the sampling process + return TorchSampler(sampler_args) + if spec_config.spec_dec_mode.is_eagle3_one_model(): + return Eagle3OneModelSampler(sampler_args) + raise ValueError( + f"Unsupported speculative decoding mode: {spec_config.spec_dec_mode}") def get_num_spec_layers(spec_config): From 84138c69f29d0fd117fcd23f3c044b360a38f7f8 Mon Sep 17 00:00:00 2001 From: Netanel Haber <58652339+netanel-haber@users.noreply.github.com> Date: Thu, 26 Jun 2025 11:14:10 +0000 Subject: [PATCH 2/4] support all features Signed-off-by: Netanel Haber minimize diff Signed-off-by: Netanel Haber minimize diff Signed-off-by: Netanel Haber --- tensorrt_llm/_torch/pyexecutor/sampler.py | 100 +++++++++++++++------- 1 file changed, 71 insertions(+), 29 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/sampler.py b/tensorrt_llm/_torch/pyexecutor/sampler.py index 885dd0c47b6..99859f79d06 100644 --- a/tensorrt_llm/_torch/pyexecutor/sampler.py +++ b/tensorrt_llm/_torch/pyexecutor/sampler.py @@ -1,6 +1,7 @@ from abc import ABC, abstractmethod from collections.abc import Iterable from dataclasses import dataclass +from typing import Literal import torch @@ -155,22 +156,35 @@ def greedy_search_sampling_batch(logits): return next_tokens, softmax -def sample_single_request(request: LlmRequest, logits: torch.Tensor): - assert logits.dim( - ) == 2 and logits.shape[0] == 1, "logits should have shape [1, vocab_size]" - if request.sampling_config.top_p is not None and len( - request.sampling_config.top_p) > 0: - return top_p_sampling_batch(logits, request.sampling_config.top_p[0]) - elif request.sampling_config.top_k is not None and len( +TopK = tuple[Literal["top_k"], int] +TopP = 
tuple[Literal["top_p"], float] +Greedy = tuple[Literal["greedy"], None] +Strategy = TopK | TopP | Greedy + + +def request_strategy(request: LlmRequest) -> Strategy: + if request.sampling_config.top_k is not None and len( request.sampling_config.top_k) > 0: - return top_k_sampling_batch(logits, request.sampling_config.top_k[0]) + return ("top_k", request.sampling_config.top_k[0]) + elif request.sampling_config.top_p is not None and len( + request.sampling_config.top_p) > 0: + return ("top_p", request.sampling_config.top_p[0]) else: - return greedy_search_sampling_batch(logits) + return ("greedy", None) + + +def sampling_strategies(requests: Iterable[LlmRequest]) -> list[Strategy]: + return [request_strategy(req) for req in requests] -def new_tokens_slice(request: LlmRequest, beam: int, *, - size: int) -> tuple[slice, int, int]: - return slice(0, size), request.seq_slot, beam +def sample(strategy: Strategy, logits: torch.Tensor): + match strategy: + case ("top_k", top_k): + return top_k_sampling_batch(logits, top_k) + case ("top_p", top_p): + return top_p_sampling_batch(logits, top_p) + case ("greedy", None): + return greedy_search_sampling_batch(logits) def add_token(request: LlmRequest, @@ -265,7 +279,7 @@ def _handle_stop_criteria(self, request: LlmRequest, new_token: int, *, def handle_logits(self, request: LlmRequest, state: SampleState, *, beam: int, count: int): - current_slice = new_tokens_slice(request, beam, size=count) + current_slice = slice(0, count), request.seq_slot, beam if request.py_return_generation_logits: assert state.host.logits is not None current_logits = state.host.logits[current_slice] @@ -365,6 +379,12 @@ def sample_async(self, scheduled_requests: ScheduledRequests, logits=gen_logits_host), sampler_event=sampler_event) + @staticmethod + def append_eagle3(tokens: torch.Tensor, model_outputs): + if "d2t" in model_outputs: + d2t = model_outputs["d2t"][tokens] + tokens += d2t + def _process_requests(self, requests: list[LlmRequest], model_outputs: dict[str, torch.Tensor], @@ -372,25 +392,47 @@ def _process_requests(self, *, gen_logits_host: torch.Tensor | None = None, log_probs_host: torch.Tensor | None = None): + beam_width = self.MAX_BEAM_WIDTH beam = self.BEAM - offset = 0 raw_logits = model_outputs["logits"] + num_steps = [1 + len(req.py_draft_tokens) for req in requests] + sum_steps = sum(num_steps) + no_draft_tokens = len(requests) == sum_steps + fast_path = not self.mixed_sampler and no_draft_tokens and gen_logits_host is None and log_probs_host is None + + seq_slots = torch.as_tensor([r.seq_slot for r in requests]) + seq_slots = seq_slots.to(device="cuda", non_blocking=True) + + if fast_path: + logits = raw_logits[:len(requests)] + next_tokens = torch.argmax(logits, dim=-1) + self.append_eagle3(next_tokens, model_outputs) + int_next_tokens = next_tokens.to(torch.int, non_blocking=True) + next_tokens = int_next_tokens.view(1, -1, beam_width) + new_tokens[:1].index_copy_(1, seq_slots, next_tokens) + return + + batched_next_tokens, batched_softmax = None, None + strategies = sampling_strategies(requests) + if len(set(strategies)) == 1: + logits = raw_logits[:sum_steps] + strategy = strategies[0] + batched_next_tokens, batched_softmax = sample(strategy, logits) + self.append_eagle3(batched_next_tokens, model_outputs) + else: + assert "d2t" not in model_outputs, "eagle3 does not yet support non-uniform sampling" - for request in requests: - steps = 1 - if len(request.py_draft_tokens) > 0: - assert not self.mixed_sampler, "Speculative decoding not supported in mixed 
sampler" - steps += len(request.py_draft_tokens) - logits = raw_logits[offset:offset + steps] - if self.mixed_sampler: - next_tokens, softmax = sample_single_request(request, logits) + offset = 0 + for strategy, slot, steps in zip(strategies, seq_slots, num_steps): + input_slice = slice(offset, offset + steps) + logits = raw_logits[input_slice] + if batched_next_tokens is None: + next_tokens, softmax = sample(strategy, logits) else: - next_tokens, softmax = greedy_search_sampling_batch(logits) - current_slice = new_tokens_slice(request, beam, size=steps) + next_tokens = batched_next_tokens[input_slice] + softmax = batched_softmax[input_slice] + current_slice = slice(0, steps), slot, beam new_tokens[current_slice] = next_tokens - if "d2t" in model_outputs: # Eagle3 - new_tokens[current_slice] += model_outputs["d2t"][ - new_tokens[current_slice]] if gen_logits_host is not None: gen_logits_host[current_slice].copy_(logits, non_blocking=True) if log_probs_host is not None: @@ -398,8 +440,8 @@ def _process_requests(self, token_probs = torch.gather( softmax, dim=1, index=next_tokens.unsqueeze(1)).squeeze(-1) log_probs = torch.log(token_probs) - log_probs_host[request.seq_slot, - beam, :steps].copy_(log_probs, non_blocking=True) + log_probs_host[slot, beam, :steps].copy_(log_probs, + non_blocking=True) offset += steps From 051fe4af69d2c699ba5dc37c8d9fc2ec30d6929c Mon Sep 17 00:00:00 2001 From: Netanel Haber <58652339+netanel-haber@users.noreply.github.com> Date: Mon, 30 Jun 2025 13:26:51 +0000 Subject: [PATCH 3/4] revert behavior change back to: if non mixed sampling, always greedy sampling Signed-off-by: Netanel Haber <58652339+netanel-haber@users.noreply.github.com> --- tensorrt_llm/_torch/pyexecutor/sampler.py | 27 ++++++++++++++--------- 1 file changed, 17 insertions(+), 10 deletions(-) diff --git a/tensorrt_llm/_torch/pyexecutor/sampler.py b/tensorrt_llm/_torch/pyexecutor/sampler.py index e3181220abb..ca7f72250f2 100644 --- a/tensorrt_llm/_torch/pyexecutor/sampler.py +++ b/tensorrt_llm/_torch/pyexecutor/sampler.py @@ -158,16 +158,17 @@ def greedy_search_sampling_batch(logits): TopK = tuple[Literal["top_k"], int] TopP = tuple[Literal["top_p"], float] Greedy = tuple[Literal["greedy"], None] +GREEDY: Greedy = ("greedy", None) Strategy = TopK | TopP | Greedy def request_strategy(request: LlmRequest) -> Strategy: - if request.sampling_config.top_k is not None and len( - request.sampling_config.top_k) > 0: - return ("top_k", request.sampling_config.top_k[0]) - elif request.sampling_config.top_p is not None and len( + if request.sampling_config.top_p is not None and len( request.sampling_config.top_p) > 0: return ("top_p", request.sampling_config.top_p[0]) + elif request.sampling_config.top_k is not None and len( + request.sampling_config.top_k) > 0: + return ("top_k", request.sampling_config.top_k[0]) else: return ("greedy", None) @@ -412,14 +413,20 @@ def _process_requests(self, return batched_next_tokens, batched_softmax = None, None - strategies = sampling_strategies(requests) - if len(set(strategies)) == 1: + batched_strategy: Strategy | None = GREEDY + if self.mixed_sampler: + strategies = sampling_strategies(requests) + assert "d2t" not in model_outputs, "eagle3 does not yet support non-greedy sampling" + if len(set(strategies)) == 1: + batched_strategy = strategies[0] + else: + batched_strategy = None + + if batched_strategy is not None: logits = raw_logits[:sum_steps] - strategy = strategies[0] - batched_next_tokens, batched_softmax = sample(strategy, logits) + batched_next_tokens, 
batched_softmax = sample( + batched_strategy, logits) self.append_eagle3(batched_next_tokens, model_outputs) - else: - assert "d2t" not in model_outputs, "eagle3 does not yet support non-uniform sampling" offset = 0 for strategy, slot, steps in zip(strategies, seq_slots, num_steps): From dd73678e8465705a048d9f0ae0c13c6b434f80f0 Mon Sep 17 00:00:00 2001 From: Netanel Haber <58652339+netanel-haber@users.noreply.github.com> Date: Mon, 30 Jun 2025 13:32:40 +0000 Subject: [PATCH 4/4] revert behavior change back to: if non mixed sampling, always greedy sampling Signed-off-by: Netanel Haber <58652339+netanel-haber@users.noreply.github.com> --- tensorrt_llm/_torch/pyexecutor/sampler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorrt_llm/_torch/pyexecutor/sampler.py b/tensorrt_llm/_torch/pyexecutor/sampler.py index ca7f72250f2..182b2344d71 100644 --- a/tensorrt_llm/_torch/pyexecutor/sampler.py +++ b/tensorrt_llm/_torch/pyexecutor/sampler.py @@ -412,10 +412,10 @@ def _process_requests(self, new_tokens[:1].index_copy_(1, seq_slots, next_tokens) return + strategies = sampling_strategies(requests) batched_next_tokens, batched_softmax = None, None batched_strategy: Strategy | None = GREEDY if self.mixed_sampler: - strategies = sampling_strategies(requests) assert "d2t" not in model_outputs, "eagle3 does not yet support non-greedy sampling" if len(set(strategies)) == 1: batched_strategy = strategies[0]
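
Editorial note, not part of the patch series: the sketch below reconstructs, under stated assumptions, the sampling-strategy dispatch that the last three sampler.py hunks converge on: strategy tuples, one batched call when every request shares a strategy (always greedy unless mixed_sampler is enabled), and a per-request fallback otherwise. The top_k/top_p branches are simplified stand-ins for the patch's top_k_sampling_batch/top_p_sampling_batch helpers (which also return the softmax used for log-probs), and process_batch, its arguments, and the toy shapes are hypothetical illustration only.

    # Standalone sketch; names and shapes here are assumptions, not TensorRT-LLM APIs.
    from typing import Literal

    import torch

    TopK = tuple[Literal["top_k"], int]
    TopP = tuple[Literal["top_p"], float]
    Greedy = tuple[Literal["greedy"], None]
    GREEDY: Greedy = ("greedy", None)
    Strategy = TopK | TopP | Greedy


    def sample(strategy: Strategy, logits: torch.Tensor) -> torch.Tensor:
        # Dispatch on the (name, parameter) tuple, mirroring the patch's match statement.
        match strategy:
            case ("greedy", None):
                return torch.argmax(logits, dim=-1)
            case ("top_k", k):
                vals, idx = torch.topk(logits, k, dim=-1)
                picked = torch.multinomial(torch.softmax(vals, dim=-1), 1)
                return idx.gather(-1, picked).squeeze(-1)
            case ("top_p", p):
                sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
                probs = torch.softmax(sorted_logits, dim=-1)
                cum = torch.cumsum(probs, dim=-1)
                sorted_logits[cum - probs > p] = float("-inf")  # keep at least one token
                picked = torch.multinomial(torch.softmax(sorted_logits, dim=-1), 1)
                return sorted_idx.gather(-1, picked).squeeze(-1)


    def process_batch(strategies: list[Strategy], raw_logits: torch.Tensor,
                      mixed_sampler: bool) -> torch.Tensor:
        # Without mixed_sampler, everything is sampled greedily in one batched call;
        # with it, a single shared strategy is still batched, and only genuinely
        # heterogeneous batches fall back to a per-request loop.
        batched: Strategy | None = GREEDY
        if mixed_sampler:
            batched = strategies[0] if len(set(strategies)) == 1 else None
        if batched is not None:
            return sample(batched, raw_logits)
        return torch.stack([
            sample(s, raw_logits[i:i + 1]).squeeze(0)
            for i, s in enumerate(strategies)
        ])


    if __name__ == "__main__":
        logits = torch.randn(3, 32)  # 3 requests, toy vocab of 32
        print(process_batch([GREEDY, ("top_k", 5), ("top_p", 0.9)], logits,
                            mixed_sampler=True))

Usage note: with mixed_sampler=False the heterogeneous strategies are ignored and the whole batch takes the greedy path, which is the behavior the last two patches deliberately revert to.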