
Commit f021f5e

Fix a few issues
Signed-off-by: ziyixiong-nv <[email protected]>
1 parent 23d55b4 commit f021f5e

4 files changed: +180, -166 lines


tensorrt_llm/_torch/pyexecutor/py_executor_creator.py

Lines changed: 2 additions & 4 deletions
@@ -41,11 +41,9 @@
 from .py_executor import PyExecutor
 
 
-# Development flag to control chain drafter feature
+# Development function to control chain drafter feature
 def _get_allow_chain_drafter() -> bool:
-    """Get the chain drafter flag from environment variable."""
-    # Use environment variable for cross-process compatibility
-    return os.getenv("TRTLLM_ALLOW_CHAIN_DRAFTER", "1") != "0"
+    return True
 
 
 class _ExecutorCreationStage(enum.Enum):
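
The removed lines read the TRTLLM_ALLOW_CHAIN_DRAFTER environment variable so the toggle would also reach spawned worker processes; the helper now simply returns True. The "cross-process compatibility" the old comment mentioned comes from environment inheritance, which an in-process flag does not provide. A small generic Python sketch of that behavior (illustrative only, not TensorRT-LLM code):

    import os
    import subprocess
    import sys

    # Environment variables set in the parent are inherited by child worker
    # processes, which is what the removed os.getenv() toggle relied on.
    os.environ["TRTLLM_ALLOW_CHAIN_DRAFTER"] = "0"
    child = subprocess.run(
        [sys.executable, "-c",
         "import os; print(os.getenv('TRTLLM_ALLOW_CHAIN_DRAFTER'))"],
        capture_output=True,
        text=True,
    )
    print(child.stdout.strip())  # -> "0": the child process sees the toggle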

tensorrt_llm/_torch/speculative/model_drafter.py

Lines changed: 6 additions & 5 deletions
@@ -396,6 +396,7 @@ def _convert_draft_tensors(
         new_tokens_lens = None
         next_draft_tokens = None
         has_draft_tokens = False
+        batch_size = new_tokens.shape[1]
         # Iterate through generation requests and copy tokens based on accepted draft tokens
         for request in scheduled_batch.all_requests():
             idx = request.py_seq_slot
@@ -411,9 +412,8 @@
 
         if has_draft_tokens:
             # We already updated the target state, so the new_tokens_lens should be all ones.
-            new_tokens_lens = torch.ones(scheduled_batch.batch_size,
-                                         device=device)
-            next_draft_tokens = torch.zeros(scheduled_batch.batch_size,
+            new_tokens_lens = torch.ones(batch_size, device=device)
+            next_draft_tokens = torch.zeros(batch_size,
                                             self.max_draft_tokens,
                                             device=device)
 
@@ -438,13 +438,14 @@ def _update_target_inputs_with_draft_tokens(
         Update target inputs with new draft tokens from sample state.
         """
         if draft_tensors is not None:
-            for request in draft_batch.all_requests():
+            for req_idx, request in enumerate(draft_batch.all_requests()):
                 # Skip prefill requests
                 if target_inputs.next_draft_tokens is None:
                     continue
 
                 # Get the index of the draft/target tokens in the device tensor
-                draft_idx = request.py_seq_slot
+                # For static draft loops, use the enumerated index; for dynamic loops, use py_seq_slot
+                draft_idx = req_idx if self.use_static_draft_loop else request.py_seq_slot
                 target_idx = req_id_to_old_request[
                     request.py_request_id].py_seq_slot
                 target_inputs.new_tokens[draft_position + 1:draft_position +
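
The first two hunks size new_tokens_lens and next_draft_tokens from the batch dimension of new_tokens instead of scheduled_batch.batch_size, and the last hunk picks the draft-token index by enumeration order when the static draft loop is used. A minimal sketch of the two indexing schemes; the shapes, seq slots, and use_static_draft_loop values below are illustrative assumptions, not taken from the repository:

    import torch

    max_draft_tokens = 4
    device = "cpu"
    # Assume, as the patch implies via new_tokens.shape[1], that dimension 1
    # of new_tokens is the batch dimension.
    new_tokens = torch.zeros(1 + max_draft_tokens, 2, 1, dtype=torch.long)
    batch_size = new_tokens.shape[1]  # replaces scheduled_batch.batch_size

    # Tensors shaped like the ones rebuilt in the second hunk.
    new_tokens_lens = torch.ones(batch_size, device=device)
    next_draft_tokens = torch.zeros(batch_size, max_draft_tokens, device=device)

    # Hypothetical requests: (enumerated batch position, persistent py_seq_slot).
    requests = [(0, 5), (1, 2)]
    for use_static_draft_loop in (True, False):
        for req_idx, py_seq_slot in requests:
            # Static draft loop: draft tensors are addressed by batch position.
            # Dynamic draft loop: they are addressed by the request's seq slot.
            draft_idx = req_idx if use_static_draft_loop else py_seq_slot
            print(use_static_draft_loop, req_idx, py_seq_slot, draft_idx)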

tests/unittest/_torch/speculative/test_dynamic_spec_decode.py

Lines changed: 89 additions & 93 deletions
@@ -1,7 +1,7 @@
 import os
 import sys
 import unittest
-from unittest.mock import patch
+from unittest.mock import Mock, patch
 
 import pytest
 import torch
@@ -15,101 +15,97 @@
 sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
 
 
+@pytest.fixture(scope="function")
+def enforce_single_worker(monkeypatch):
+    monkeypatch.setenv("TLLM_WORKER_USE_SINGLE_PROCESS", "1")
+    yield
+
+
 @pytest.mark.parametrize("disable_overlap_scheduler", [True, False])
 @pytest.mark.high_cuda_memory
-def test_dynamic_spec_decode(disable_overlap_scheduler: bool):
-    # Store original value and set environment variable
-    original_value = os.environ.get("TLLM_WORKER_USE_SINGLE_PROCESS")
+def test_dynamic_spec_decode(enforce_single_worker,
+                             disable_overlap_scheduler: bool):
     # mock_should_use_spec_decode doesn't work with multiple processes,
-    # so we set the environment variable to 1 in this test.
-    os.environ["TLLM_WORKER_USE_SINGLE_PROCESS"] = "1"
-
-    try:
-        total_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
-        if total_mem_gb < 35:
-            pytest.skip("Not enough memory to load target + draft model")
-
-        models_path = llm_models_root()
-        eagle_model_dir = f"{models_path}/EAGLE3-LLaMA3.1-Instruct-8B"
-        target_model_dir = f"{models_path}/llama-3.1-model/Llama-3.1-8B-Instruct"
-
-        max_batch_size = 1
-        max_draft_len = 4
-        kv_cache_config = KvCacheConfig(enable_block_reuse=True,
-                                        max_tokens=8192)
-        cuda_graph_config = CudaGraphConfig(batch_sizes=[1])
-
-        llm_common_config = dict(
-            model=target_model_dir,
-            attn_backend="TRTLLM",
-            disable_overlap_scheduler=disable_overlap_scheduler,
-            cuda_graph_config=cuda_graph_config,
-            max_batch_size=max_batch_size,
-            kv_cache_config=kv_cache_config,
-            # This max_seq_len is larger than the one specified
-            # in the llama 3 8B eagle's config. We want to make sure
-            # that the draft model won't go above its max in warmup
-            # in this test.
-            max_seq_len=8192,
-        )
-
-        spec_config = EagleDecodingConfig(
-            max_draft_len=max_draft_len,
-            speculative_model_dir=eagle_model_dir,
-            # Llama 3 does not support one model eagle.
-            eagle3_one_model=False,
-        )
-
-        # Mock should_use_spec_decode to turn on/off spec decode dynamically.
-        def mock_should_use_spec_decode(requests, max_batch_size,
-                                        max_num_tokens, max_draft_len):
-            if not hasattr(mock_should_use_spec_decode, 'call_count'):
-                mock_should_use_spec_decode.call_count = 0
-
-            for req in requests:
-                if req.state != LlmRequestState.GENERATION_IN_PROGRESS:
-                    continue
-
-                mock_should_use_spec_decode.call_count += 1
-                # Turn off spec decode when we've called it 5 times.
-                # In the current case, at the 5th call, there are 2 accepted draft tokens,
-                # so we can have better coverage for the switching between spec decode on and off.
-                if mock_should_use_spec_decode.call_count > 5:
-                    return False
-            return True
-
-        llm_spec = LLM(**llm_common_config, speculative_config=spec_config)
-        sampling_params = SamplingParams(max_tokens=128, temperature=0)
-
-        # Output tests
-        prompts = [
-            "The president of the United States is",
-        ]
-        sampling_params = SamplingParams(max_tokens=20, temperature=0)
-
-        with patch(
-                'tensorrt_llm._torch.speculative.model_drafter.ModelDrafter.should_use_spec_decode',
-                side_effect=mock_should_use_spec_decode):
-            results_spec = llm_spec.generate(prompts, sampling_params)
-        generated_text_spec = [
-            result.outputs[0].text for result in results_spec
-        ]
-        llm_spec.shutdown()
-
-        llm_ref = LLM(**llm_common_config)
-        results_ref = llm_ref.generate(prompts, sampling_params)
-        generated_text_ref = [result.outputs[0].text for result in results_ref]
-        llm_ref.shutdown()
-
-        for text_spec, text_ref in zip(generated_text_spec, generated_text_ref):
-            # The spec decode algorithm currently guarantees identical results
-            assert text_spec == text_ref
-    finally:
-        # Restore original environment variable value
-        if original_value is None:
-            os.environ.pop("TLLM_WORKER_USE_SINGLE_PROCESS", None)
-        else:
-            os.environ["TLLM_WORKER_USE_SINGLE_PROCESS"] = original_value
+    # so we use the enforce_single_worker fixture to set the environment variable.
+    total_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
+    if total_mem_gb < 35:
+        pytest.skip("Not enough memory to load target + draft model")
+
+    models_path = llm_models_root()
+    eagle_model_dir = f"{models_path}/EAGLE3-LLaMA3.1-Instruct-8B"
+    target_model_dir = f"{models_path}/llama-3.1-model/Llama-3.1-8B-Instruct"
+
+    max_batch_size = 1
+    max_draft_len = 4
+    kv_cache_config = KvCacheConfig(enable_block_reuse=True, max_tokens=8192)
+    cuda_graph_config = CudaGraphConfig(batch_sizes=[1])
+
+    llm_common_config = dict(
+        model=target_model_dir,
+        attn_backend="TRTLLM",
+        disable_overlap_scheduler=disable_overlap_scheduler,
+        cuda_graph_config=cuda_graph_config,
+        max_batch_size=max_batch_size,
+        kv_cache_config=kv_cache_config,
+        # This max_seq_len is larger than the one specified
+        # in the llama 3 8B eagle's config. We want to make sure
+        # that the draft model won't go above its max in warmup
+        # in this test.
+        max_seq_len=8192,
+    )
+
+    spec_config = EagleDecodingConfig(
+        max_draft_len=max_draft_len,
+        speculative_model_dir=eagle_model_dir,
+        # Llama 3 does not support one model eagle.
+        eagle3_one_model=False,
+    )
+
+    llm_spec = LLM(**llm_common_config, speculative_config=spec_config)
+    sampling_params = SamplingParams(max_tokens=128, temperature=0)
+
+    # Output tests
+    prompts = [
+        "The president of the United States is",
+    ]
+    sampling_params = SamplingParams(max_tokens=20, temperature=0)
+
+    # Mock should_use_spec_decode to turn on/off spec decode dynamically.
+    def mock_should_use_spec_decode(requests, max_batch_size, max_num_tokens,
+                                    max_draft_len):
+        for req in requests:
+            if req.state != LlmRequestState.GENERATION_IN_PROGRESS:
+                continue
+
+            mock_should_use_spec_decode.call_count += 1
+            # Turn off spec decode when we've called it 5 times.
+            # In the current case, at the 5th call, there are 2 accepted draft tokens,
+            # so we can have better coverage for the switching between spec decode on and off.
+            if mock_should_use_spec_decode.call_count > 5:
+                return False
+        return True
+
+    # Create a Mock object with the mock function as side_effect
+    mock_should_use_spec_decode = Mock(side_effect=mock_should_use_spec_decode)
+    # Reset mock state before using it
+    mock_should_use_spec_decode.reset_mock()
+    mock_should_use_spec_decode.call_count = 0
+
+    with patch(
+            'tensorrt_llm._torch.speculative.model_drafter.ModelDrafter.should_use_spec_decode',
+            mock_should_use_spec_decode):
+        results_spec = llm_spec.generate(prompts, sampling_params)
+    generated_text_spec = [result.outputs[0].text for result in results_spec]
+    llm_spec.shutdown()
+
+    llm_ref = LLM(**llm_common_config)
+    results_ref = llm_ref.generate(prompts, sampling_params)
+    generated_text_ref = [result.outputs[0].text for result in results_ref]
+    llm_ref.shutdown()
+
+    for text_spec, text_ref in zip(generated_text_spec, generated_text_ref):
+        # The spec decode algorithm currently guarantees identical results
+        assert text_spec == text_ref
 
 
 def test_should_use_spec_decode():
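
The test now relies on two standard patterns instead of hand-rolled environment handling: a monkeypatch fixture, which undoes setenv automatically at teardown, and Mock(side_effect=...), which keeps the fake's logic while exposing resettable call state to patch(). A self-contained sketch of both; the Drafter class and fake_should_use_spec_decode below are hypothetical stand-ins, not the real ModelDrafter API:

    import os
    from unittest.mock import Mock, patch

    import pytest


    @pytest.fixture(scope="function")
    def enforce_single_worker(monkeypatch):
        # monkeypatch restores the original environment at test teardown,
        # replacing the old test's manual try/finally save-and-restore.
        monkeypatch.setenv("TLLM_WORKER_USE_SINGLE_PROCESS", "1")
        yield


    class Drafter:
        def should_use_spec_decode(self, requests):
            return True


    def test_mock_side_effect(enforce_single_worker):
        assert os.environ["TLLM_WORKER_USE_SINGLE_PROCESS"] == "1"

        def fake_should_use_spec_decode(requests):
            # Mock(side_effect=...) calls this function and returns its result,
            # while the Mock itself records call_count and call arguments.
            return len(requests) == 0

        mock = Mock(side_effect=fake_should_use_spec_decode)
        mock.reset_mock()  # clears any recorded calls before use
        with patch.object(Drafter, "should_use_spec_decode", mock):
            assert Drafter().should_use_spec_decode([]) is True
        assert mock.call_count == 1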
