1919from  tensorrt_llm .quantization  import  QuantAlgo 
2020
2121from  ..conftest  import  (llm_models_root , parametrize_with_ids , skip_no_nvls ,
22-                         skip_pre_ada , skip_pre_blackwell , skip_pre_hopper )
22+                         skip_post_blackwell , skip_pre_ada , skip_pre_blackwell ,
23+                         skip_pre_hopper )
2324from  .accuracy_core  import  (MMLU , CliFlowAccuracyTestHarness , CnnDailymail ,
2425                            Humaneval , PassKeyRetrieval64k ,
2526                            PassKeyRetrieval128k , SlimPajama6B , ZeroScrolls )
@@ -57,6 +58,7 @@ def test_weight_only(self, precision: str):
5758    def  test_int8_kv_cache (self ):
5859        self .run (kv_cache_quant_algo = QuantAlgo .INT8 )
5960
61+     @skip_post_blackwell  
6062    @parametrize_with_ids ("per_token,per_channel" , [(False , False ), 
6163                                                    (True , True )]) 
6264    def  test_smooth_quant (self , per_token : bool , per_channel : bool ):
@@ -142,6 +144,7 @@ class TestStarcoder2_15B(CliFlowAccuracyTestHarness):
142144    MODEL_PATH  =  f"{ llm_models_root ()}  
143145    EXAMPLE_FOLDER  =  "gpt" 
144146
147+     @skip_post_blackwell  
145148    def  test_smooth_quant_ootb (self ):
146149        self .run (tasks = [Humaneval (self .MODEL_NAME )],
147150                 quant_algo = QuantAlgo .W8A8_SQ_PER_CHANNEL )
@@ -194,9 +197,11 @@ class TestPhi2(CliFlowAccuracyTestHarness):
194197    MODEL_PATH  =  f"{ llm_models_root ()}  
195198    EXAMPLE_FOLDER  =  "phi" 
196199
200+     @skip_post_blackwell  
197201    def  test_auto_dtype (self ):
198202        self .run (dtype = 'auto' )
199203
204+     @skip_post_blackwell  
200205    @pytest .mark .skip_less_device (2 ) 
201206    def  test_tp2 (self ):
202207        self .run (tp_size = 2 )
@@ -316,6 +321,7 @@ def test_medusa(self, cuda_graph, mocker):
316321                 extra_build_args = ["--speculative_decoding_mode=medusa" ],
317322                 extra_summarize_args = extra_summarize_args )
318323
324+     @skip_post_blackwell  
319325    @parametrize_with_ids ("cuda_graph,chunked_context,typical_acceptance" , 
320326                          [(False , False , False ), (True , False , False ), 
321327                           (True , True , False ), (True , False , True )]) 
@@ -360,6 +366,7 @@ def test_beam_search(self):
360366                 extra_build_args = ["--max_beam_width=5" ],
361367                 extra_summarize_args = ["--num_beams=5" ])
362368
369+     @skip_post_blackwell  
363370    def  test_int4_gptq (self ):
364371        self .run (
365372            quant_algo = QuantAlgo .W4A16_GPTQ ,
@@ -386,6 +393,7 @@ class TestLlama2_7B(CliFlowAccuracyTestHarness):
386393    def  test_auto_dtype (self ):
387394        self .run (dtype = 'auto' )
388395
396+     @skip_post_blackwell  
389397    def  test_smooth_quant (self ):
390398        self .run (quant_algo = QuantAlgo .W8A8_SQ_PER_CHANNEL_PER_TOKEN_PLUGIN )
391399
@@ -433,21 +441,25 @@ def test_fp8_low_latency_gemm_plugin(self):
433441                 extra_build_args = ["--low_latency_gemm_plugin=fp8" ])
434442
435443    @pytest .mark .skip_less_device (2 ) 
444+     @skip_post_blackwell  
436445    def  test_smooth_quant_ootb_tp2 (self ):
437446        self .run (quant_algo = QuantAlgo .W8A8_SQ_PER_CHANNEL , tp_size = 2 )
438447
439448    @pytest .mark .skip_less_device (2 ) 
449+     @skip_post_blackwell  
440450    def  test_int4_awq_tp2 (self ):
441451        self .run (quant_algo = QuantAlgo .W4A16_AWQ , tp_size = 2 )
442452
443453    @pytest .mark .skip_less_device (2 ) 
454+     @skip_post_blackwell  
444455    def  test_int4_awq_prequantized_tp2 (self , mocker ):
445456        mocker .patch .object (
446457            self .__class__ , "MODEL_PATH" ,
447458            f"{ llm_models_root ()}  )
448459        self .run (quant_algo = QuantAlgo .W4A16_AWQ , tp_size = 2 )
449460
450461    @pytest .mark .skip_less_device (2 ) 
462+     @skip_post_blackwell  
451463    def  test_int4_gptq_prequantized_tp2 (self , mocker ):
452464        mocker .patch .object (
453465            self .__class__ , "MODEL_PATH" ,
@@ -469,16 +481,19 @@ def test_auto_dtype(self):
469481    def  test_float32 (self ):
470482        self .run (dtype = 'float32' )
471483
484+     @skip_post_blackwell  
472485    @pytest .mark .parametrize ("precision" , ["int8" , "int4" ]) 
473486    def  test_weight_only (self , precision : str ):
474487        quant_algo  =  QuantAlgo .W8A16  if  precision  ==  "int8"  else  QuantAlgo .W4A16 
475488        self .run (quant_algo = quant_algo )
476489
490+     @skip_post_blackwell  
477491    @pytest .mark .parametrize ("precision" , ["int8" , "int4" ]) 
478492    def  test_weight_only_int8_kv_cache (self , precision : str ):
479493        quant_algo  =  QuantAlgo .W8A16  if  precision  ==  "int8"  else  QuantAlgo .W4A16 
480494        self .run (quant_algo = quant_algo , kv_cache_quant_algo = QuantAlgo .INT8 )
481495
496+     @skip_post_blackwell  
482497    @pytest .mark .parametrize ("precision" , ["int8" , "int4" ]) 
483498    def  test_weight_only_manage_weights (self , precision : str ):
484499        quant_algo  =  QuantAlgo .W8A16  if  precision  ==  "int8"  else  QuantAlgo .W4A16 
@@ -567,6 +582,7 @@ class TestLlama3_1_8B(CliFlowAccuracyTestHarness):
567582    def  test_auto_dtype (self ):
568583        self .run (dtype = 'auto' )
569584
585+     @skip_post_blackwell  
570586    def  test_smooth_quant (self ):
571587        self .run (quant_algo = QuantAlgo .W8A8_SQ_PER_CHANNEL_PER_TOKEN_PLUGIN )
572588
@@ -575,12 +591,14 @@ def test_fp8(self):
575591        self .run (quant_algo = QuantAlgo .FP8 , kv_cache_quant_algo = QuantAlgo .FP8 )
576592
577593    @skip_pre_ada  
594+     @skip_post_blackwell  
578595    def  test_fp8_rowwise (self ):
579596        self .run (tasks = [CnnDailymail (self .MODEL_NAME ),
580597                        MMLU (self .MODEL_NAME )],
581598                 quant_algo = QuantAlgo .FP8_PER_CHANNEL_PER_TOKEN )
582599
583600    @skip_pre_ada  
601+     @skip_post_blackwell  
584602    def  test_fp8_rowwise_meta_recipe (self ):
585603        self .run (quant_algo = QuantAlgo .FP8_PER_CHANNEL_PER_TOKEN ,
586604                 extra_acc_spec = "meta_recipe" ,
@@ -601,6 +619,7 @@ def test_tp4(self, gemm_allreduce: bool):
601619            extra_build_args = extra_build_args )
602620
603621    @skip_pre_ada  
622+     @skip_post_blackwell  
604623    @pytest .mark .skip_less_device (4 ) 
605624    @pytest .mark .parametrize ( 
606625        "gemm_allreduce" , [False , pytest .param (True , marks = skip_no_nvls )], 
@@ -646,6 +665,7 @@ def test_fp8_prequantized(self, mocker):
646665        self .run (quant_algo = QuantAlgo .FP8 , kv_cache_quant_algo = QuantAlgo .FP8 )
647666
648667    @skip_pre_ada  
668+     @skip_post_blackwell  
649669    def  test_medusa_fp8_prequantized (self , mocker ):
650670        # nvidia/Llama-3.1-8B-Medusa-FP8 
651671        mocker .patch .object (self .__class__ , "MODEL_PATH" ,
@@ -670,23 +690,29 @@ class TestLlama3_2_1B(CliFlowAccuracyTestHarness):
670690    def  test_auto_dtype (self ):
671691        self .run (dtype = 'auto' )
672692
693+     @skip_post_blackwell  
673694    def  test_smooth_quant (self ):
674695        self .run (quant_algo = QuantAlgo .W8A8_SQ_PER_CHANNEL_PER_TOKEN_PLUGIN )
675696
697+     @skip_post_blackwell  
676698    def  test_smooth_quant_ootb (self ):
677699        self .run (quant_algo = QuantAlgo .W8A8_SQ_PER_CHANNEL )
678700
701+     @skip_post_blackwell  
679702    def  test_smooth_quant_ootb_manage_weights (self ):
680703        self .run (quant_algo = QuantAlgo .W8A8_SQ_PER_CHANNEL ,
681704                 extra_build_args = ["--fast_build" ])
682705
706+     @skip_post_blackwell  
683707    def  test_int4_awq (self ):
684708        self .run (quant_algo = QuantAlgo .W4A16_AWQ )
685709
710+     @skip_post_blackwell  
686711    def  test_int4_awq_int8_kv_cache (self ):
687712        self .run (quant_algo = QuantAlgo .W4A16_AWQ ,
688713                 kv_cache_quant_algo = QuantAlgo .INT8 )
689714
715+     @skip_post_blackwell  
690716    def  test_int4_awq_manage_weights (self ):
691717        self .run (quant_algo = QuantAlgo .W4A16_AWQ ,
692718                 extra_build_args = ["--fast_build" ])
@@ -733,10 +759,12 @@ def test_fp8_pp2(self):
733759                 pp_size = 2 )
734760
735761    @skip_pre_ada  
762+     @skip_post_blackwell  
736763    def  test_fp8_rowwise (self ):
737764        self .run (quant_algo = QuantAlgo .FP8_PER_CHANNEL_PER_TOKEN )
738765
739766    @skip_pre_ada  
767+     @skip_post_blackwell  
740768    def  test_fp8_rowwise_meta_recipe (self ):
741769        self .run (quant_algo = QuantAlgo .FP8_PER_CHANNEL_PER_TOKEN ,
742770                 extra_acc_spec = "meta_recipe" ,
@@ -830,6 +858,7 @@ def test_weight_only(self, precision: str):
830858        quant_algo  =  QuantAlgo .W8A16  if  precision  ==  "int8"  else  QuantAlgo .W4A16 
831859        self .run (quant_algo = quant_algo , extra_convert_args = ["--ckpt-type=hf" ])
832860
861+     @skip_post_blackwell  
833862    def  test_smooth_quant (self ):
834863        self .run (quant_algo = QuantAlgo .W8A8_SQ_PER_CHANNEL_PER_TOKEN_PLUGIN ,
835864                 extra_convert_args = [
@@ -841,6 +870,7 @@ def test_smooth_quant(self):
841870    def  test_fp8 (self ):
842871        self .run (quant_algo = QuantAlgo .FP8 , kv_cache_quant_algo = QuantAlgo .FP8 )
843872
873+     @skip_post_blackwell  
844874    def  test_int4_awq (self ):
845875        self .run (quant_algo = QuantAlgo .W4A16_AWQ )
846876
@@ -859,6 +889,7 @@ def test_weight_only(self, precision: str):
859889        quant_algo  =  QuantAlgo .W8A16  if  precision  ==  "int8"  else  QuantAlgo .W4A16 
860890        self .run (quant_algo = quant_algo , extra_convert_args = ["--ckpt-type=hf" ])
861891
892+     @skip_post_blackwell  
862893    @pytest .mark .skip_less_device_memory (50000 ) 
863894    def  test_smooth_quant (self ):
864895        self .run (quant_algo = QuantAlgo .W8A8_SQ_PER_CHANNEL_PER_TOKEN_PLUGIN ,
@@ -871,6 +902,7 @@ def test_smooth_quant(self):
871902    def  test_fp8 (self ):
872903        self .run (quant_algo = QuantAlgo .FP8 , kv_cache_quant_algo = QuantAlgo .FP8 )
873904
905+     @skip_post_blackwell  
874906    def  test_int4_awq (self ):
875907        self .run (quant_algo = QuantAlgo .W4A16_AWQ )
876908
@@ -887,6 +919,7 @@ def test_auto_dtype(self):
887919                 dtype = 'auto' ,
888920                 extra_convert_args = ["--ckpt-type=hf" ])
889921
922+     @skip_post_blackwell  
890923    @pytest .mark .parametrize ("precision" , ["int8" , "int4" ]) 
891924    def  test_weight_only (self , precision : str ):
892925        quant_algo  =  QuantAlgo .W8A16  if  precision  ==  "int8"  else  QuantAlgo .W4A16 
@@ -910,6 +943,7 @@ def test_auto_dtype(self):
910943    def  test_weight_only (self ):
911944        self .run (quant_algo = QuantAlgo .W8A16 )
912945
946+     @skip_post_blackwell  
913947    def  test_int4_gptq_prequantized (self , mocker ):
914948        mocker .patch .object (self .__class__ , "MODEL_PATH" ,
915949                            f"{ llm_models_root ()}  )
@@ -938,6 +972,7 @@ class TestQwen2_0_5BInstruct(CliFlowAccuracyTestHarness):
938972    def  test_auto_dtype (self ):
939973        self .run (dtype = 'auto' )
940974
975+     @skip_post_blackwell  
941976    def  test_weight_only (self ):
942977        self .run (quant_algo = QuantAlgo .W8A16 )
943978
@@ -956,9 +991,11 @@ class TestQwen2_7BInstruct(CliFlowAccuracyTestHarness):
956991    def  test_auto_dtype (self ):
957992        self .run (dtype = 'auto' )
958993
994+     @skip_post_blackwell  
959995    def  test_weight_only (self ):
960996        self .run (quant_algo = QuantAlgo .W8A16 )
961997
998+     @skip_post_blackwell  
962999    def  test_int4_awq_prequantized (self , mocker ):
9631000        mocker .patch .object (self .__class__ , "MODEL_PATH" ,
9641001                            f"{ llm_models_root ()}  )
@@ -990,6 +1027,7 @@ class TestQwen2_5_1_5BInstruct(CliFlowAccuracyTestHarness):
9901027    def  test_auto_dtype (self ):
9911028        self .run (dtype = 'auto' )
9921029
1030+     @skip_post_blackwell  
9931031    def  test_weight_only (self ):
9941032        self .run (quant_algo = QuantAlgo .W8A16 )
9951033
0 commit comments