1 | 1 | import os |
2 | 2 | import sys |
3 | 3 | import unittest |
4 | | -from unittest.mock import patch |
| 4 | +from unittest.mock import Mock, patch |
5 | 5 |
6 | 6 | import pytest |
7 | 7 | import torch |
15 | 15 | sys.path.append(os.path.join(os.path.dirname(__file__), '..')) |
16 | 16 |
17 | 17 |
| 18 | +@pytest.fixture(scope="function") |
| 19 | +def enforce_single_worker(monkeypatch): |
| 20 | + monkeypatch.setenv("TLLM_WORKER_USE_SINGLE_PROCESS", "1") |
| 21 | + yield |
| 22 | + |
| 23 | + |
18 | 24 | @pytest.mark.parametrize("disable_overlap_scheduler", [True, False]) |
19 | 25 | @pytest.mark.high_cuda_memory |
20 | | -def test_dynamic_spec_decode(disable_overlap_scheduler: bool): |
21 | | - # Store original value and set environment variable |
22 | | - original_value = os.environ.get("TLLM_WORKER_USE_SINGLE_PROCESS") |
| 26 | +def test_dynamic_spec_decode(enforce_single_worker, |
| 27 | + disable_overlap_scheduler: bool): |
23 | 28 | # mock_should_use_spec_decode doesn't work with multiple processes, |
24 | | - # so we set the environment variable to 1 in this test. |
25 | | - os.environ["TLLM_WORKER_USE_SINGLE_PROCESS"] = "1" |
26 | | - |
27 | | - try: |
28 | | - total_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9 |
29 | | - if total_mem_gb < 35: |
30 | | - pytest.skip("Not enough memory to load target + draft model") |
31 | | - |
32 | | - models_path = llm_models_root() |
33 | | - eagle_model_dir = f"{models_path}/EAGLE3-LLaMA3.1-Instruct-8B" |
34 | | - target_model_dir = f"{models_path}/llama-3.1-model/Llama-3.1-8B-Instruct" |
35 | | - |
36 | | - max_batch_size = 1 |
37 | | - max_draft_len = 4 |
38 | | - kv_cache_config = KvCacheConfig(enable_block_reuse=True, |
39 | | - max_tokens=8192) |
40 | | - cuda_graph_config = CudaGraphConfig(batch_sizes=[1]) |
41 | | - |
42 | | - llm_common_config = dict( |
43 | | - model=target_model_dir, |
44 | | - attn_backend="TRTLLM", |
45 | | - disable_overlap_scheduler=disable_overlap_scheduler, |
46 | | - cuda_graph_config=cuda_graph_config, |
47 | | - max_batch_size=max_batch_size, |
48 | | - kv_cache_config=kv_cache_config, |
49 | | - # This max_seq_len is larger than the one specified |
50 | | - # in the llama 3 8B eagle's config. We want to make sure |
51 | | - # that the draft model won't go above its max in warmup |
52 | | - # in this test. |
53 | | - max_seq_len=8192, |
54 | | - ) |
55 | | - |
56 | | - spec_config = EagleDecodingConfig( |
57 | | - max_draft_len=max_draft_len, |
58 | | - speculative_model_dir=eagle_model_dir, |
59 | | - # Llama 3 does not support one model eagle. |
60 | | - eagle3_one_model=False, |
61 | | - ) |
62 | | - |
63 | | - # Mock should_use_spec_decode to turn on/off spec decode dynamically. |
64 | | - def mock_should_use_spec_decode(requests, max_batch_size, |
65 | | - max_num_tokens, max_draft_len): |
66 | | - if not hasattr(mock_should_use_spec_decode, 'call_count'): |
67 | | - mock_should_use_spec_decode.call_count = 0 |
68 | | - |
69 | | - for req in requests: |
70 | | - if req.state != LlmRequestState.GENERATION_IN_PROGRESS: |
71 | | - continue |
72 | | - |
73 | | - mock_should_use_spec_decode.call_count += 1 |
74 | | -                # Turn spec decode off once it has been called 5 times. |
75 | | -                # At the 5th call there are 2 accepted draft tokens in this scenario, |
76 | | -                # which gives better coverage of switching spec decode on and off. |
77 | | - if mock_should_use_spec_decode.call_count > 5: |
78 | | - return False |
79 | | - return True |
80 | | - |
81 | | - llm_spec = LLM(**llm_common_config, speculative_config=spec_config) |
82 | | - sampling_params = SamplingParams(max_tokens=128, temperature=0) |
83 | | - |
84 | | - # Output tests |
85 | | - prompts = [ |
86 | | - "The president of the United States is", |
87 | | - ] |
88 | | - sampling_params = SamplingParams(max_tokens=20, temperature=0) |
89 | | - |
90 | | - with patch( |
91 | | - 'tensorrt_llm._torch.speculative.model_drafter.ModelDrafter.should_use_spec_decode', |
92 | | - side_effect=mock_should_use_spec_decode): |
93 | | - results_spec = llm_spec.generate(prompts, sampling_params) |
94 | | - generated_text_spec = [ |
95 | | - result.outputs[0].text for result in results_spec |
96 | | - ] |
97 | | - llm_spec.shutdown() |
98 | | - |
99 | | - llm_ref = LLM(**llm_common_config) |
100 | | - results_ref = llm_ref.generate(prompts, sampling_params) |
101 | | - generated_text_ref = [result.outputs[0].text for result in results_ref] |
102 | | - llm_ref.shutdown() |
103 | | - |
104 | | - for text_spec, text_ref in zip(generated_text_spec, generated_text_ref): |
105 | | - # The spec decode algorithm currently guarantees identical results |
106 | | - assert text_spec == text_ref |
107 | | - finally: |
108 | | - # Restore original environment variable value |
109 | | - if original_value is None: |
110 | | - os.environ.pop("TLLM_WORKER_USE_SINGLE_PROCESS", None) |
111 | | - else: |
112 | | - os.environ["TLLM_WORKER_USE_SINGLE_PROCESS"] = original_value |
| 29 | + # so we use the enforce_single_worker fixture to set the environment variable. |
| 30 | + total_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9 |
| 31 | + if total_mem_gb < 35: |
| 32 | + pytest.skip("Not enough memory to load target + draft model") |
| 33 | + |
| 34 | + models_path = llm_models_root() |
| 35 | + eagle_model_dir = f"{models_path}/EAGLE3-LLaMA3.1-Instruct-8B" |
| 36 | + target_model_dir = f"{models_path}/llama-3.1-model/Llama-3.1-8B-Instruct" |
| 37 | + |
| 38 | + max_batch_size = 1 |
| 39 | + max_draft_len = 4 |
| 40 | + kv_cache_config = KvCacheConfig(enable_block_reuse=True, max_tokens=8192) |
| 41 | + cuda_graph_config = CudaGraphConfig(batch_sizes=[1]) |
| 42 | + |
| 43 | + llm_common_config = dict( |
| 44 | + model=target_model_dir, |
| 45 | + attn_backend="TRTLLM", |
| 46 | + disable_overlap_scheduler=disable_overlap_scheduler, |
| 47 | + cuda_graph_config=cuda_graph_config, |
| 48 | + max_batch_size=max_batch_size, |
| 49 | + kv_cache_config=kv_cache_config, |
| 50 | + # This max_seq_len is larger than the one specified |
| 51 | + # in the llama 3 8B eagle's config. We want to make sure |
| 52 | + # that the draft model won't go above its max in warmup |
| 53 | + # in this test. |
| 54 | + max_seq_len=8192, |
| 55 | + ) |
| 56 | + |
| 57 | + spec_config = EagleDecodingConfig( |
| 58 | + max_draft_len=max_draft_len, |
| 59 | + speculative_model_dir=eagle_model_dir, |
| 60 | + # Llama 3 does not support one model eagle. |
| 61 | + eagle3_one_model=False, |
| 62 | + ) |
| 63 | + |
| 64 | + llm_spec = LLM(**llm_common_config, speculative_config=spec_config) |
| 65 | + sampling_params = SamplingParams(max_tokens=128, temperature=0) |
| 66 | + |
| 67 | + # Output tests |
| 68 | + prompts = [ |
| 69 | + "The president of the United States is", |
| 70 | + ] |
| 71 | + sampling_params = SamplingParams(max_tokens=20, temperature=0) |
| 72 | + |
| 73 | + # Mock should_use_spec_decode to turn on/off spec decode dynamically. |
| 74 | + def mock_should_use_spec_decode(requests, max_batch_size, max_num_tokens, |
| 75 | + max_draft_len): |
| 76 | + for req in requests: |
| 77 | + if req.state != LlmRequestState.GENERATION_IN_PROGRESS: |
| 78 | + continue |
| 79 | + |
| 80 | + mock_should_use_spec_decode.call_count += 1 |
| 81 | +            # Turn spec decode off once it has been called 5 times. |
| 82 | +            # At the 5th call there are 2 accepted draft tokens in this scenario, |
| 83 | +            # which gives better coverage of switching spec decode on and off. |
| 84 | + if mock_should_use_spec_decode.call_count > 5: |
| 85 | + return False |
| 86 | + return True |
| 87 | + |
| 88 | + # Create a Mock object with the mock function as side_effect |
| 89 | + mock_should_use_spec_decode = Mock(side_effect=mock_should_use_spec_decode) |
| 90 | + # Reset mock state before using it |
| 91 | + mock_should_use_spec_decode.reset_mock() |
| 92 | + mock_should_use_spec_decode.call_count = 0 |
| 93 | + |
| 94 | + with patch( |
| 95 | + 'tensorrt_llm._torch.speculative.model_drafter.ModelDrafter.should_use_spec_decode', |
| 96 | + mock_should_use_spec_decode): |
| 97 | + results_spec = llm_spec.generate(prompts, sampling_params) |
| 98 | + generated_text_spec = [result.outputs[0].text for result in results_spec] |
| 99 | + llm_spec.shutdown() |
| 100 | + |
| 101 | + llm_ref = LLM(**llm_common_config) |
| 102 | + results_ref = llm_ref.generate(prompts, sampling_params) |
| 103 | + generated_text_ref = [result.outputs[0].text for result in results_ref] |
| 104 | + llm_ref.shutdown() |
| 105 | + |
| 106 | + for text_spec, text_ref in zip(generated_text_spec, generated_text_ref): |
| 107 | + # The spec decode algorithm currently guarantees identical results |
| 108 | + assert text_spec == text_ref |
113 | 109 |
114 | 110 |
115 | 111 | def test_should_use_spec_decode(): |
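Below is a minimal, self-contained sketch of the two patterns this change relies on, using hypothetical names (single_process_env, SOME_ENV_FLAG, flip_after_three) and a stand-in patch target (os.path.exists) instead of the TensorRT-LLM APIs: a function-scoped fixture that sets an environment variable through pytest's monkeypatch, which undoes the change at teardown and so replaces the manual try/finally restore logic removed above, and a plain function wrapped in unittest.mock.Mock as its side_effect so that patch can track calls while the function decides the return value.

import os
from unittest.mock import Mock, patch

import pytest


@pytest.fixture(scope="function")
def single_process_env(monkeypatch):
    # monkeypatch.setenv restores the previous value (or removes the variable)
    # automatically at teardown, so no try/finally cleanup is needed.
    monkeypatch.setenv("SOME_ENV_FLAG", "1")
    yield


def flip_after_three(_path):
    # Stand-in decision function: True for the first 3 calls, False afterwards.
    flip_after_three.calls = getattr(flip_after_three, "calls", 0) + 1
    return flip_after_three.calls <= 3


def test_patch_with_mock_side_effect(single_process_env):
    assert os.environ["SOME_ENV_FLAG"] == "1"

    mock_decide = Mock(side_effect=flip_after_three)
    # Patch a real callable (os.path.exists, purely for illustration) with the
    # Mock: the Mock records calls, the side_effect supplies the return values.
    with patch("os.path.exists", mock_decide):
        results = [os.path.exists("ignored") for _ in range(5)]

    assert results == [True, True, True, False, False]
    assert mock_decide.call_count == 5

The Mock records every call while the side_effect decides the result, which is the same division of labor the patched should_use_spec_decode uses in the test above.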