42 changes: 23 additions & 19 deletions tensorrt_llm/_torch/pyexecutor/model_engine.py
@@ -437,6 +437,10 @@ def __init__(
else:
self.cache_indirection_attention = None

@property
def runtime_draft_len(self):
return self.max_draft_len if self.enable_spec_decode else 0

def set_lora_model_config(self, lora_target_modules: list[str],
trtllm_modules_to_hf_modules: dict[str, str]):
self.lora_model_config = LoraModelConfig(
@@ -557,7 +561,7 @@ def get_torch_compile_warmup_request(batch_size,
list(range(batch_size)), [num_tokens_per_request] *
batch_size if not is_gen else None,
is_gen=is_gen,
max_num_draft_tokens=self.max_draft_len)
max_num_draft_tokens=self.runtime_draft_len)

if spec_resource_manager is not None:
spec_resource_manager.add_dummy_requests(
@@ -576,7 +580,7 @@ def get_torch_compile_warmup_request(batch_size,

def get_autotune_warmup_request():
available_tokens = kv_cache_manager.get_num_available_tokens(
self.max_draft_len)
self.runtime_draft_len)
num_tokens_per_request = min(
min(available_tokens, self.max_seq_len - 1),
self.max_num_tokens)
@@ -610,14 +614,14 @@ def get_autotune_warmup_request():
request_ids=list(range(full_len_request_num)),
token_nums=[num_tokens_per_request] * full_len_request_num,
is_gen=False,
max_num_draft_tokens=self.max_draft_len)
max_num_draft_tokens=self.runtime_draft_len)

if remaining_tokens > 0:
final_request = kv_cache_manager.add_dummy_requests(
request_ids=[full_len_request_num],
token_nums=[remaining_tokens],
is_gen=False,
max_num_draft_tokens=self.max_draft_len)
max_num_draft_tokens=self.runtime_draft_len)

requests += final_request

@@ -664,7 +668,7 @@ def disable_optimization(backend: Backend):
# Disable cuda graph capture here so that we can properly capture it later
with self.no_cuda_graph():
available_tokens = kv_cache_manager.get_num_available_tokens(
self.max_draft_len)
self.runtime_draft_len)
warmup_batch_size = [1, self.batch_size // 2]
if self.batch_size < 2:
warmup_batch_size = [1]
@@ -879,7 +883,7 @@ def _get_padded_batch(
self.cuda_graph_dummy_request = kv_cache_manager.add_dummy_requests(
cuda_graph_dummy_request_ids,
is_gen=True,
max_num_draft_tokens=self.max_draft_len,
max_num_draft_tokens=self.runtime_draft_len,
use_mrope=self.use_mrope,
max_beam_width=self.max_beam_width)[0]
self.cuda_graph_dummy_request.is_cuda_graph_dummy = True
@@ -1306,7 +1310,7 @@ def _prepare_tp_inputs(
gather_ids.extend(
list(
range(len(position_ids),
len(position_ids) + 1 + self.max_draft_len)))
len(position_ids) + 1 + self.runtime_draft_len)))
position_ids.extend(
list(
range(past_seen_token_num,
@@ -1322,23 +1326,23 @@ def _prepare_tp_inputs(
# inputs
# overlap scheduler can only support the speculative decoding
# methods with a fixed number of draft tokens
sequence_lengths.append(1 + self.max_draft_len)
sequence_lengths.append(1 + self.runtime_draft_len)
past_seen_token_num = request.max_beam_num_tokens - 1
draft_lens.append(self.max_draft_len)
draft_lens.append(self.runtime_draft_len)
gather_ids.extend(
list(
range(len(position_ids),
len(position_ids) + 1 + self.max_draft_len)))
len(position_ids) + 1 + self.runtime_draft_len)))
position_ids.extend(
list(
range(past_seen_token_num,
past_seen_token_num + 1 + self.max_draft_len)))
range(past_seen_token_num, past_seen_token_num + 1 +
self.runtime_draft_len)))
# previous tensor
previous_batch_indices.append(previous_batch_idx)
previous_pos_indices.extend([previous_batch_idx] *
(1 + self.max_draft_len))
(1 + self.runtime_draft_len))
num_cached_tokens_per_seq.append(past_seen_token_num +
self.max_draft_len + 1)
self.runtime_draft_len + 1)
prompt_lengths.append(request.py_prompt_len)
request_ids.append(request.py_request_id)

@@ -1412,21 +1416,21 @@ def previous_seq_slots_device():
previous_slots = previous_seq_slots_device()
# previous input ids
previous_batch_tokens = previous_batch_len * (
1 + self.max_draft_len)
1 + self.runtime_draft_len)
new_tokens = new_tokens_device.transpose(
0, 1)[previous_slots, :].flatten()
self.input_ids_cuda[num_tokens:num_tokens +
previous_batch_tokens].copy_(
new_tokens, non_blocking=True)
# previous draft tokens
previous_batch_draft_tokens = previous_batch_len * self.max_draft_len
previous_batch_draft_tokens = previous_batch_len * self.runtime_draft_len
self.draft_tokens_cuda[num_draft_tokens:num_draft_tokens +
previous_batch_draft_tokens].copy_(
next_draft_tokens_device[
previous_slots, :].flatten(),
non_blocking=True)
# prepare data for the preprocess inputs
kv_len_offsets_device = new_tokens_lens_device - self.max_draft_len - 1
kv_len_offsets_device = new_tokens_lens_device - self.runtime_draft_len - 1
previous_pos_indices_host = torch.tensor(previous_pos_indices,
dtype=torch.int,
pin_memory=True)
@@ -1451,8 +1455,8 @@ def previous_seq_slots_device():
extend_dummy_requests)
self.previous_pos_id_offsets_cuda[
(num_extend_reqeust_wo_dummy - previous_batch_len) *
(1 + self.max_draft_len):num_extend_reqeust_wo_dummy *
(1 + self.max_draft_len)].copy_(
(1 + self.runtime_draft_len):num_extend_reqeust_wo_dummy *
(1 + self.runtime_draft_len)].copy_(
new_tokens_lens_device[self.previous_pos_indices_cuda[
0:previous_batch_tokens]],
non_blocking=True)
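Reviewer note: the new `runtime_draft_len` property is what lets every buffer-sizing and dummy-request call above collapse to the non-speculative case when spec decode is toggled off at runtime. A minimal standalone sketch of the pattern; everything except `max_draft_len`, `enable_spec_decode`, and `runtime_draft_len` is illustrative, not the engine's real interface:

```python
class EngineSketch:
    """Illustrative only: shows how runtime_draft_len gates draft-token math."""

    def __init__(self, max_draft_len: int, enable_spec_decode: bool):
        self.max_draft_len = max_draft_len
        self.enable_spec_decode = enable_spec_decode

    @property
    def runtime_draft_len(self) -> int:
        # 0 when spec decode is disabled, so sequence lengths, draft buffers,
        # and KV-cache token counts fall back to the plain decoding sizes.
        return self.max_draft_len if self.enable_spec_decode else 0

    def tokens_per_generation_step(self) -> int:
        # Mirrors sequence_lengths.append(1 + self.runtime_draft_len) in the diff.
        return 1 + self.runtime_draft_len


engine = EngineSketch(max_draft_len=4, enable_spec_decode=False)
assert engine.tokens_per_generation_step() == 1  # no draft tokens when disabled
```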
10 changes: 5 additions & 5 deletions tensorrt_llm/_torch/pyexecutor/sampler.py
@@ -336,7 +336,7 @@ def process_draft_tokens(self, request: LlmRequest,
if request.py_draft_logits is None:
new_token = add_token(request, new_tokens, beam=self.BEAM)
stop = self._handle_stop_criteria(request, new_token)
if stop or len(request.py_draft_tokens) == 0:
if stop or get_draft_token_length(request) == 0:
return 0
num_accepted = 0

@@ -360,10 +360,10 @@ def process_draft_tokens(self, request: LlmRequest,
request.py_draft_logits[0],
generator=generator)
target_probs = request.py_target_probs
p = draft_probs[torch.arange(len(request.py_draft_tokens)),
p = draft_probs[torch.arange(get_draft_token_length(request)),
request.py_draft_tokens]
q = target_probs[:-1]
q = q[torch.arange(len(request.py_draft_tokens)),
q = q[torch.arange(get_draft_token_length(request)),
request.py_draft_tokens]
accept_probs = torch.minimum(torch.ones(()), q / p)
# Use deterministic random generation for multi-GPU consistency
@@ -374,7 +374,7 @@ def process_draft_tokens(self, request: LlmRequest,
sample_last = True
stop = False
if rejected_indices.numel() == 0:
num_initially_accepted = len(request.py_draft_tokens)
num_initially_accepted = get_draft_token_length(request)
sample_last = False
else:
num_initially_accepted = rejected_indices[0].item()
@@ -575,7 +575,7 @@ def _process_requests(self,
logits = raw_logits[:sum_steps]
# Collect steps per request for batched strategy
steps_per_request = [
1 + len(req.py_draft_tokens) for req in requests
1 + get_draft_token_length(req) for req in requests
]
logits = self._apply_embedding_bias(logits, requests,
steps_per_request)
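Reviewer note: the sampler now calls `get_draft_token_length(request)` instead of `len(request.py_draft_tokens)`, so requests without draft tokens are handled uniformly. Below is a self-contained sketch of the rejection-sampling step this code implements; the helper's exact behaviour and the tensor shapes are assumptions for illustration, not the library API:

```python
import torch


def draft_token_length(draft_tokens) -> int:
    # Assumed behaviour of get_draft_token_length: None and [] both mean 0.
    return len(draft_tokens) if draft_tokens else 0


def accept_draft_tokens(draft_tokens: list[int],
                        draft_probs: torch.Tensor,   # [num_draft, vocab]
                        target_probs: torch.Tensor,  # [num_draft + 1, vocab]
                        generator: torch.Generator) -> int:
    """Standard speculative-decoding rejection sampling, as in process_draft_tokens."""
    num_draft = draft_token_length(draft_tokens)
    if num_draft == 0:
        return 0
    idx = torch.arange(num_draft)
    p = draft_probs[idx, draft_tokens]        # draft-model prob of each drafted token
    q = target_probs[:-1][idx, draft_tokens]  # target-model prob of the same tokens
    accept_probs = torch.minimum(torch.ones(()), q / p)
    rejected = (torch.rand(num_draft, generator=generator) > accept_probs).nonzero()
    # Accept everything if nothing was rejected; otherwise keep only the prefix
    # before the first rejection (the target model then resamples that position).
    return num_draft if rejected.numel() == 0 else int(rejected[0].item())
```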
2 changes: 2 additions & 0 deletions tests/integration/defs/accuracy/accuracy_core.py
@@ -155,6 +155,8 @@ def evaluate(self,
spec_dec_algo = None
elif isinstance(llm.args.speculative_config, DecodingBaseConfig):
spec_dec_algo = llm.args.speculative_config.decoding_type
if spec_dec_algo == 'AUTO':
spec_dec_algo = 'NGram'
else:
raise ValueError(
f"Not recognized speculative_config: {llm.args.speculative_config}."
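Reviewer note: the two added lines make the harness look up the NGram reference thresholds for runs whose `decoding_type` reports `'AUTO'`. A minimal sketch of the mapping; the function name is illustrative, not part of the harness:

```python
def resolve_spec_dec_algo(decoding_type: str | None) -> str | None:
    """Map the reported decoding type to the accuracy-reference key."""
    if decoding_type == 'AUTO':
        # AUTO runs are scored against the NGram reference thresholds.
        return 'NGram'
    return decoding_type
```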
25 changes: 21 additions & 4 deletions tests/integration/defs/accuracy/test_llm_api_pytorch.py
@@ -21,10 +21,10 @@
from tensorrt_llm._torch.modules.fused_moe.fused_moe_triton import \
IS_TRITON_KERNELS_AVAILABLE
from tensorrt_llm._torch.pyexecutor.config import MoeLoadBalancerConfig
from tensorrt_llm.llmapi import (CudaGraphConfig, EagleDecodingConfig,
KvCacheConfig, MoeConfig, MTPDecodingConfig,
NGramDecodingConfig, SamplingParams,
TorchCompileConfig)
from tensorrt_llm.llmapi import (AutoDecodingConfig, CudaGraphConfig,
EagleDecodingConfig, KvCacheConfig, MoeConfig,
MTPDecodingConfig, NGramDecodingConfig,
SamplingParams, TorchCompileConfig)
from tensorrt_llm.quantization import QuantAlgo

from ..conftest import (llm_models_root, parametrize_with_ids, skip_no_hopper,
@@ -356,6 +356,23 @@ def test_guided_decoding_with_ngram(self, backend: str, mocker):
task = JsonModeEval(self.MODEL_NAME)
task.evaluate(llm)

@skip_pre_hopper
def test_auto_spec_decode(self):
pytorch_config = {
"cuda_graph_config":
CudaGraphConfig(batch_sizes=[1, 32, 64], enable_padding=True)
}
kv_cache_config = KvCacheConfig(enable_block_reuse=False,
free_gpu_memory_fraction=0.5)
spec_config = AutoDecodingConfig()
with LLM(model=self.MODEL_PATH,
**pytorch_config,
kv_cache_config=kv_cache_config,
speculative_config=spec_config,
max_batch_size=64) as llm:
task = GSM8K(self.MODEL_NAME)
task.evaluate(llm)


class TestLlama3_2_1B(LlmapiAccuracyTestHarness):
MODEL_NAME = "meta-llama/Llama-3.2-1B"
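Reviewer note: the new `test_auto_spec_decode` covers the end-to-end path of building an `LLM` with `speculative_config=AutoDecodingConfig()` and checking GSM8K accuracy. A trimmed usage sketch outside the test harness, relying only on the constructor arguments shown in the diff; the checkpoint path is a placeholder:

```python
from tensorrt_llm import LLM
from tensorrt_llm.llmapi import AutoDecodingConfig, CudaGraphConfig, KvCacheConfig

MODEL_PATH = "/models/llama-3.1-8b-instruct"  # placeholder; use a real checkpoint dir

with LLM(model=MODEL_PATH,
         cuda_graph_config=CudaGraphConfig(batch_sizes=[1, 32, 64],
                                           enable_padding=True),
         kv_cache_config=KvCacheConfig(enable_block_reuse=False,
                                       free_gpu_memory_fraction=0.5),
         speculative_config=AutoDecodingConfig(),  # runtime chooses the draft scheme
         max_batch_size=64) as llm:
    outputs = llm.generate(["The capital of France is"])
    print(outputs[0].outputs[0].text)
```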