# Scratch-buffer size shared by the decode kernels: 128 MiB.
workspace_size = 128 * 1024 * 1024
12- @pytest .mark .parametrize (
13- "batch_size" ,
14- [1 , 2 , 4 , 16 , 32 , 64 , 128 , 256 , 512 , 768 , 1024 ],
15- )
16- @pytest .mark .parametrize ("scale" , [1.0 , 0.5 ])
17- @pytest .mark .parametrize ("dtype" , [torch .float8_e4m3fn , torch .bfloat16 ])
18- @pytest .mark .parametrize ("page_size" , [32 , 64 ])
19- @pytest .mark .parametrize (
20- "q_len_per_request" , [1 , 2 ]
21- ) # todo(Yingyi): verify larger q_len_per_request
22- @pytest .mark .parametrize ("dynamic_scale" , [False ])
23- @pytest .mark .parametrize ("enable_pdl" , [True , False , None ])
24- @pytest .mark .parametrize ("backend" , ["trtllm-gen" , "xqa" ])
25- def test_trtllm_batch_decode_mla (
12+ def trtllm_batch_decode_mla (
2613 batch_size : int ,
2714 scale : float ,
2815 dtype : torch .dtype ,
@@ -31,6 +18,7 @@ def test_trtllm_batch_decode_mla(
3118 dynamic_scale : bool ,
3219 enable_pdl : bool ,
3320 backend : str ,
21+ MAX_SEQ_LEN : int ,
3422):
3523 compute_capability = get_compute_capability (torch .device (device = "cuda" ))
3624 if backend == "xqa" :
@@ -49,9 +37,6 @@ def test_trtllm_batch_decode_mla(
4937 torch .manual_seed (42 )
5038 device = "cuda:0"
5139
52- # Fixed max sequence length
53- MAX_SEQ_LEN = 1024
54-
5540 # Deepseek attention config (decode-MLA)
5641 num_q_heads = 128
5742 qk_nope_head_dim = 128
@@ -239,3 +224,75 @@ def test_trtllm_batch_decode_mla(
239224 f"Total { o_ref .numel ()} elements, only { pass_ratio :.1%} meet tolerance criteria, "
240225 f"require at least { required_ratio :.1%} "
241226 )
227+
228+
@pytest.mark.parametrize(
    "batch_size",
    [1, 2, 4, 16, 32, 64, 128, 256, 512, 768, 1024],
)
@pytest.mark.parametrize("scale", [1.0, 0.5])
@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16])
@pytest.mark.parametrize("page_size", [32, 64])
@pytest.mark.parametrize(
    "q_len_per_request", [1, 2]
)  # todo(Yingyi): verify larger q_len_per_request
@pytest.mark.parametrize("dynamic_scale", [False])
@pytest.mark.parametrize("enable_pdl", [True, False, None])
@pytest.mark.parametrize("backend", ["trtllm-gen", "xqa"])
def test_trtllm_batch_decode_mla(
    batch_size: int,
    scale: float,
    dtype: torch.dtype,
    page_size: int,
    q_len_per_request: int,
    dynamic_scale: bool,
    enable_pdl: bool,
    backend: str,
):
    """Batch decode-MLA sweep with the default 1024 max sequence length.

    Thin pytest wrapper: all checking lives in the shared
    ``trtllm_batch_decode_mla`` helper; this variant pins
    ``MAX_SEQ_LEN`` to the historical fixed value of 1024.
    """
    trtllm_batch_decode_mla(
        batch_size,
        scale,
        dtype,
        page_size,
        q_len_per_request,
        dynamic_scale,
        enable_pdl,
        backend,
        1024,  # MAX_SEQ_LEN: matches the pre-refactor hard-coded constant
    )
263+
264+
@pytest.mark.parametrize(
    "batch_size",
    [2, 4, 8],
)
@pytest.mark.parametrize("scale", [1.0, 0.5])
@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn, torch.bfloat16])
@pytest.mark.parametrize("page_size", [64])
@pytest.mark.parametrize("q_len_per_request", [1, 2, 3])
@pytest.mark.parametrize("dynamic_scale", [False])
@pytest.mark.parametrize("enable_pdl", [True, False, None])
@pytest.mark.parametrize("backend", ["trtllm-gen"])
@pytest.mark.parametrize("MAX_SEQ_LEN", [1024, 8960])
def test_dsr1_trtllm_mla(
    batch_size: int,
    scale: float,
    dtype: torch.dtype,
    page_size: int,
    q_len_per_request: int,
    dynamic_scale: bool,
    enable_pdl: bool,
    backend: str,
    MAX_SEQ_LEN: int,
):
    """Decode-MLA sweep over larger max sequence lengths (trtllm-gen only).

    Thin pytest wrapper around the shared ``trtllm_batch_decode_mla``
    helper; unlike ``test_trtllm_batch_decode_mla`` it also parametrizes
    ``MAX_SEQ_LEN`` (presumably DSR1 = DeepSeek-R1 coverage — confirm).
    """
    trtllm_batch_decode_mla(
        batch_size,
        scale,
        dtype,
        page_size,
        q_len_per_request,
        dynamic_scale,
        enable_pdl,
        backend,
        MAX_SEQ_LEN,
    )