@@ -339,40 +339,7 @@ def unpack_compare_nvfp4(
339339 return output_unpacked , output_ref
340340
341341
342- @pytest .mark .parametrize ("kv_layout" , ["HND" , "NHD" ])
343- @pytest .mark .parametrize (
344- "batch_size,page_size,num_kv_heads,head_grp_size" ,
345- [
346- (4 , 16 , 2 , 1 ),
347- (4 , 32 , 4 , 5 ),
348- (4 , 64 , 4 , 8 ),
349- (128 , 16 , 2 , 5 ),
350- (128 , 32 , 4 , 1 ),
351- (128 , 64 , 2 , 8 ),
352- (256 , 16 , 4 , 8 ),
353- (256 , 32 , 2 , 8 ),
354- (256 , 64 , 4 , 1 ),
355- (256 , 64 , 4 , 5 ),
356- ],
357- )
358- @pytest .mark .parametrize ("window_left" , [- 1 ]) # todo(Siyuan): add 127 window_left
359- @pytest .mark .parametrize (
360- "q_dtype,kv_dtype,o_dtype" ,
361- [
362- ("bf16" , "bf16" , "bf16" ),
363- ("fp16" , "fp16" , "fp16" ),
364- ("fp8" , "fp8" , "bf16" ),
365- ("fp8" , "fp8" , "fp16" ),
366- ("fp8" , "fp8" , "fp8" ),
367- ("fp8" , "fp8" , "nvfp4" ),
368- ],
369- )
370- @pytest .mark .parametrize ("enable_pdl" , [True , False , None ])
371- @pytest .mark .parametrize ("enable_sink" , [True , False ])
372- @pytest .mark .parametrize ("max_q_len" , [511 ])
373- @pytest .mark .parametrize ("max_kv_len" , [2047 ])
374- @pytest .mark .parametrize ("device_scale" , [True , False ])
375- def test_trtllm_batch_prefill (
342+ def _test_trtllm_batch_prefill (
376343 kv_layout ,
377344 batch_size ,
378345 page_size ,
@@ -580,6 +547,71 @@ def test_trtllm_batch_prefill(
580547 assert (workspace_buffer [: 8192 * 256 * 4 ].cpu ().numpy () == 0 ).all ()
581548
582549
550+ @pytest .mark .parametrize ("kv_layout" , ["HND" , "NHD" ])
551+ @pytest .mark .parametrize (
552+ "batch_size,page_size,num_kv_heads,head_grp_size" ,
553+ [
554+ (4 , 16 , 2 , 1 ),
555+ (4 , 32 , 4 , 5 ),
556+ (4 , 64 , 4 , 8 ),
557+ (128 , 16 , 2 , 5 ),
558+ (128 , 32 , 4 , 1 ),
559+ (128 , 64 , 2 , 8 ),
560+ (256 , 16 , 4 , 8 ),
561+ (256 , 32 , 2 , 8 ),
562+ (256 , 64 , 4 , 1 ),
563+ (256 , 64 , 4 , 5 ),
564+ ],
565+ )
566+ @pytest .mark .parametrize ("window_left" , [- 1 ]) # todo(Siyuan): add 127 window_left
567+ @pytest .mark .parametrize (
568+ "q_dtype,kv_dtype,o_dtype" ,
569+ [
570+ ("bf16" , "bf16" , "bf16" ),
571+ ("fp16" , "fp16" , "fp16" ),
572+ ("fp8" , "fp8" , "bf16" ),
573+ ("fp8" , "fp8" , "fp16" ),
574+ ("fp8" , "fp8" , "fp8" ),
575+ ("fp8" , "fp8" , "nvfp4" ),
576+ ],
577+ )
578+ @pytest .mark .parametrize ("enable_pdl" , [None ])
579+ @pytest .mark .parametrize ("enable_sink" , [True , False ])
580+ @pytest .mark .parametrize ("max_q_len" , [511 ])
581+ @pytest .mark .parametrize ("max_kv_len" , [2047 ])
582+ def test_trtllm_batch_prefill (
583+ kv_layout ,
584+ batch_size ,
585+ page_size ,
586+ num_kv_heads ,
587+ head_grp_size ,
588+ window_left ,
589+ q_dtype ,
590+ o_dtype ,
591+ kv_dtype ,
592+ enable_pdl ,
593+ enable_sink ,
594+ max_q_len ,
595+ max_kv_len ,
596+ ):
597+ _test_trtllm_batch_prefill (
598+ kv_layout ,
599+ batch_size ,
600+ page_size ,
601+ num_kv_heads ,
602+ head_grp_size ,
603+ window_left ,
604+ q_dtype ,
605+ o_dtype ,
606+ kv_dtype ,
607+ enable_pdl ,
608+ enable_sink ,
609+ max_q_len ,
610+ max_kv_len ,
611+ kv_dtype == "fp8" ,
612+ )
613+
614+
583615@pytest .mark .parametrize ("kv_layout" , ["HND" , "NHD" ])
584616@pytest .mark .parametrize (
585617 "batch_size,page_size,num_kv_heads,head_grp_size" ,
@@ -613,7 +645,7 @@ def test_trtllm_batch_prefill_bs1(
613645 max_q_len ,
614646 max_kv_len ,
615647):
616- test_trtllm_batch_prefill (
648+ _test_trtllm_batch_prefill (
617649 kv_layout ,
618650 batch_size ,
619651 page_size ,
@@ -966,7 +998,6 @@ def _test_trtllm_batch_decode(
966998@pytest .mark .parametrize ("enable_sink" , [True , False ])
967999@pytest .mark .parametrize ("max_in_kv_len" , [110 ])
9681000@pytest .mark .parametrize ("head_dim" , [128 ])
969- @pytest .mark .parametrize ("device_scale" , [True , False ])
9701001def test_trtllm_batch_decode (
9711002 backend ,
9721003 kv_layout ,
@@ -983,7 +1014,6 @@ def test_trtllm_batch_decode(
9831014 enable_sink ,
9841015 max_in_kv_len ,
9851016 head_dim ,
986- device_scale ,
9871017):
9881018 # General set of tests for trtllm-gen decode
9891019 _test_trtllm_batch_decode (
@@ -1002,7 +1032,7 @@ def test_trtllm_batch_decode(
10021032 enable_sink ,
10031033 max_in_kv_len ,
10041034 head_dim ,
1005- device_scale ,
1035+ kv_dtype == "fp8" ,
10061036 )
10071037
10081038
@@ -1024,6 +1054,7 @@ def test_trtllm_batch_decode(
10241054@pytest .mark .parametrize ("enable_sink" , [False ])
10251055@pytest .mark .parametrize ("max_in_kv_len" , [8192 ])
10261056@pytest .mark .parametrize ("head_dim" , [128 ])
1057+ @pytest .mark .parametrize ("device_scale" , [True , False ])
10271058def test_trtllm_batch_decode_bs1 (
10281059 kv_layout ,
10291060 batch_size ,
@@ -1039,6 +1070,7 @@ def test_trtllm_batch_decode_bs1(
10391070 enable_sink ,
10401071 max_in_kv_len ,
10411072 head_dim ,
1073+ device_scale ,
10421074):
10431075 # Small number of test cases for batch size 1
10441076 pytest .xfail ("trtllm-gen decode gets incorrect output with bs1" )
@@ -1058,7 +1090,7 @@ def test_trtllm_batch_decode_bs1(
10581090 enable_sink ,
10591091 max_in_kv_len ,
10601092 head_dim ,
1061- False ,
1093+ device_scale ,
10621094 )
10631095
10641096
@@ -1091,6 +1123,7 @@ def test_trtllm_batch_decode_bs1(
10911123@pytest .mark .parametrize ("enable_sink" , [False ])
10921124@pytest .mark .parametrize ("max_in_kv_len" , [110 ])
10931125@pytest .mark .parametrize ("head_dim" , [256 ])
1126+ @pytest .mark .parametrize ("device_scale" , [True , False ])
10941127def test_trtllm_batch_decode_head_dim_256 (
10951128 kv_layout ,
10961129 batch_size ,
@@ -1106,6 +1139,7 @@ def test_trtllm_batch_decode_head_dim_256(
11061139 enable_sink ,
11071140 max_in_kv_len ,
11081141 head_dim ,
1142+ device_scale ,
11091143):
11101144 # Small number of test cases for head_dim = 256
11111145 pytest .xfail ("trtllm-gen decode gets incorrect output with head_dim = 256" )
@@ -1125,7 +1159,7 @@ def test_trtllm_batch_decode_head_dim_256(
11251159 enable_sink ,
11261160 max_in_kv_len ,
11271161 head_dim ,
1128- True ,
1162+ device_scale ,
11291163 )
11301164
11311165
@@ -1151,6 +1185,7 @@ def test_trtllm_batch_decode_head_dim_256(
11511185@pytest .mark .parametrize ("enable_sink" , [False ])
11521186@pytest .mark .parametrize ("max_in_kv_len" , [4096 , 8192 , 16384 , 32768 , 65536 , 131072 ])
11531187@pytest .mark .parametrize ("head_dim" , [128 ])
1188+ @pytest .mark .parametrize ("device_scale" , [True , False ])
11541189def test_trtllm_batch_decode_long_sequence_length (
11551190 kv_layout ,
11561191 batch_size ,
@@ -1166,6 +1201,7 @@ def test_trtllm_batch_decode_long_sequence_length(
11661201 enable_sink ,
11671202 max_in_kv_len ,
11681203 head_dim ,
1204+ device_scale ,
11691205):
11701206 # Small number of test cases for long sequence length
11711207 _test_trtllm_batch_decode (
@@ -1184,7 +1220,7 @@ def test_trtllm_batch_decode_long_sequence_length(
11841220 enable_sink ,
11851221 max_in_kv_len ,
11861222 head_dim ,
1187- False ,
1223+ device_scale ,
11881224 )
11891225
11901226
0 commit comments