@@ -24,38 +24,40 @@ def enforce_single_worker(monkeypatch):
 
 
 @pytest.mark.parametrize(
-    "use_cuda_graph,attn_backend,disable_overlap_scheduler,enable_block_reuse,use_one_model,enable_chunked_prefill,use_chain_drafter,multi_batch",
+    "use_cuda_graph,attn_backend,disable_overlap_scheduler,enable_block_reuse,use_one_model,enable_chunked_prefill,use_chain_drafter,multi_batch,attention_dp",
     [
-        [True, "TRTLLM", True, False, False, False, True, False],
-        [True, "TRTLLM", True, False, False, False, False, False],
-        [False, "TRTLLM", True, False, False, False, True, False],
-        [False, "TRTLLM", True, False, False, False, False, False],
-        [True, "FLASHINFER", True, False, False, False, True, False],
-        [False, "FLASHINFER", True, False, False, False, True, False],
-        [False, "TRTLLM", False, True, True, False, True, False],
-        [True, "TRTLLM", False, True, True, False, True, False],
-        [True, "TRTLLM", True, False, True, True, True, False],
-        [True, "TRTLLM", True, False, True, False, True, False],
+        [True, "TRTLLM", True, False, False, False, True, False, False],
+        [True, "TRTLLM", True, False, False, False, False, False, False],
+        [False, "TRTLLM", True, False, False, False, True, False, False],
+        [False, "TRTLLM", True, False, False, False, False, False, False],
+        [True, "FLASHINFER", True, False, False, False, True, False, False],
+        [False, "FLASHINFER", True, False, False, False, True, False, False],
+        [False, "TRTLLM", False, True, True, False, True, False, False],
+        [True, "TRTLLM", False, True, True, False, True, False, False],
+        [True, "TRTLLM", True, False, True, True, True, False, False],
+        [True, "TRTLLM", True, False, True, False, True, False, False],
         # TODO: nvbugs/5461761
         # [True, "TRTLLM", True, False, False, True, True, False],
-        [True, "TRTLLM", False, False, False, False, True, False],
-        [False, "TRTLLM", False, False, False, False, True, False],
-        [True, "TRTLLM", False, False, False, False, False, True],
-        [False, "TRTLLM", False, False, False, False, False, True],
-        [True, "TRTLLM", False, False, False, False, True, True],
-        [False, "TRTLLM", False, False, False, False, True, True],
-        [True, "TRTLLM", False, False, False, False, False, False],
-        [False, "TRTLLM", False, False, False, False, False, False],
-        [True, "TRTLLM", False, False, False, True, True, False],
-        [True, "TRTLLM", False, False, False, True, False, False],
-        [True, "FLASHINFER", False, False, False, False, True, False],
-        [False, "FLASHINFER", False, False, False, False, True, False],
+        [True, "TRTLLM", False, False, False, False, True, False, False],
+        [False, "TRTLLM", False, False, False, False, True, False, False],
+        [True, "TRTLLM", False, False, False, False, False, True, False],
+        [True, "TRTLLM", False, False, False, False, False, True, True],
+        [False, "TRTLLM", False, False, False, False, False, True, False],
+        [True, "TRTLLM", False, False, False, False, True, True, False],
+        [False, "TRTLLM", False, False, False, False, True, True, False],
+        [True, "TRTLLM", False, False, False, False, False, False, False],
+        [False, "TRTLLM", False, False, False, False, False, False, False],
+        [True, "TRTLLM", False, False, False, True, True, False, False],
+        [True, "TRTLLM", False, False, False, True, False, False, False],
+        [True, "FLASHINFER", False, False, False, False, True, False, False],
+        [False, "FLASHINFER", False, False, False, False, True, False, False],
     ])
 @pytest.mark.high_cuda_memory
 def test_llama_eagle3(use_cuda_graph: bool, attn_backend: str,
                       disable_overlap_scheduler: bool, enable_block_reuse: bool,
                       use_one_model: bool, enable_chunked_prefill: bool,
-                      use_chain_drafter: bool, multi_batch: bool, request):
+                      use_chain_drafter: bool, multi_batch: bool,
+                      attention_dp: bool, request):
     # Use enforce_single_worker fixture only when use_chain_drafter is False.
     # Otherwise, we can't modify the returned value of _get_allow_chain_drafter in multiprocessing.
     if not use_chain_drafter:
@@ -98,6 +100,7 @@ def test_llama_eagle3(use_cuda_graph: bool, attn_backend: str,
         cuda_graph_config=cuda_graph_config,
         max_batch_size=max_batch_size,
         kv_cache_config=kv_cache_config,
+        enable_attention_dp=attention_dp,
         # This max_seq_len is larger than the one specified
         # in the llama 3 8B eagle's config. We want to make sure
         # that the draft model won't go above its max in warmup
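
For reference, a minimal sketch of how the new trailing `attention_dp` column is consumed: pytest zips the parameter-name string with each row, so the ninth value becomes the `attention_dp` argument of `test_llama_eagle3`, which the test then forwards verbatim as `enable_attention_dp`. Everything below that is not in the diff (the `bound`/`llm_kwargs` names, the placeholder comment for the other kwargs) is illustrative only, not the test's actual code.

```python
# Parameter names exactly as added in this diff.
param_names = ("use_cuda_graph,attn_backend,disable_overlap_scheduler,"
               "enable_block_reuse,use_one_model,enable_chunked_prefill,"
               "use_chain_drafter,multi_batch,attention_dp")

# One of the new rows that enables both multi_batch and attention_dp.
row = [True, "TRTLLM", False, False, False, False, False, True, True]

# pytest.mark.parametrize pairs names with values positionally,
# so the ninth entry lands in the new attention_dp argument.
bound = dict(zip(param_names.split(","), row))
assert bound["attention_dp"] is True

# Inside the test, the flag is passed straight through to the LLM kwargs:
llm_kwargs = dict(
    # ... other kwargs from the test (cuda_graph_config, kv_cache_config, ...)
    enable_attention_dp=bound["attention_dp"],
)
```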