     sparse_cutlass_supported)
 from vllm.platforms import current_platform
 
+# AITER only supports per-channel-per-channel INT8 GEMM
+# and per-tensor-per-tensor INT8 GEMM.
+# It does not support mixed-precision matmuls or mixed quantization schemes.
+ROCM_AITER_SUPPORTED_INT8_MODEL = [
+    "neuralmagic/Llama-3.2-1B-quantized.w8a8",
+    "nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2"
+]
+
+# TritonScaledMMLinearKernel only supports symmetric quantization.
+ROCM_TRITON_SCALED_MM_SUPPORTED_INT8_MODEL = [
+    "nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change",
+    "nm-testing/tinyllama-oneshot-w8-channel-a8-tensor",
+    "neuralmagic/Llama-3.2-1B-quantized.w8a8",
+    "nm-testing/tinyllama-oneshot-w8a8-dynamic-token-v2",
+    "nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2",
+]
+
 
 @pytest.fixture(scope="function", autouse=True)
 def use_v0_only(monkeypatch):
@@ -57,6 +74,11 @@ def use_v0_only(monkeypatch):
 )
 def test_compressed_tensors_w8a8_static_setup(vllm_runner, model_args):
     model_path, strategy, quant_type, shape_0, is_symmetric = model_args
+
+    if current_platform.is_rocm(
+    ) and model_path not in ROCM_TRITON_SCALED_MM_SUPPORTED_INT8_MODEL:
+        pytest.skip(f"Skip model {model_path} as it is not supported on ROCm.")
+
     with vllm_runner(model_path, enforce_eager=True) as llm:
 
         def check_model(model):
@@ -123,14 +145,30 @@ def zp_valid(zp: Optional[torch.Tensor]):
 )
 @pytest.mark.parametrize("max_tokens", [32])
 @pytest.mark.parametrize("num_logprobs", [10])
+@pytest.mark.parametrize(
+    "use_aiter", [True, False] if current_platform.is_rocm() else [False])
 def test_compressed_tensors_w8a8_logprobs(
     hf_runner,
     vllm_runner,
     example_prompts,
     model_path,
     max_tokens,
     num_logprobs,
+    use_aiter,
+    monkeypatch,
 ):
+
+    if current_platform.is_rocm(
+    ) and model_path not in ROCM_TRITON_SCALED_MM_SUPPORTED_INT8_MODEL:
+        pytest.skip(f"Skip model {model_path} as it is not supported on ROCm.")
+
+    if use_aiter:
+        if model_path not in ROCM_AITER_SUPPORTED_INT8_MODEL:
+            pytest.skip(
+                f"Skip model {model_path} as it is not supported by AITER.")
+        # This will also enable VLLM_ROCM_USE_AITER_LINEAR.
+        monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
+
     dtype = "bfloat16"
 
     # skip language translation prompt for the static per tensor asym model
@@ -154,6 +192,9 @@ def test_compressed_tensors_w8a8_logprobs(
         name_1="vllm",
     )
 
+    if current_platform.is_rocm():
+        torch.cuda.synchronize()
+
 
 def test_compressed_tensors_no_enforce_eager(vllm_runner):
     model_path = "nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change"
@@ -177,8 +218,27 @@ def test_compressed_tensors_no_enforce_eager(vllm_runner):
         ),
     ],
 )
-def test_compressed_tensors_w8a8_dynamic_per_token(vllm_runner, model_args):
+@pytest.mark.parametrize(
+    "use_aiter", [True, False] if current_platform.is_rocm() else [False])
+def test_compressed_tensors_w8a8_dynamic_per_token(
+    vllm_runner,
+    model_args,
+    use_aiter,
+    monkeypatch,
+):
     model_path, strategy = model_args
+
+    if current_platform.is_rocm(
+    ) and model_path not in ROCM_TRITON_SCALED_MM_SUPPORTED_INT8_MODEL:
+        pytest.skip(f"Skip model {model_path} as it is not supported on ROCm.")
+
+    if use_aiter:
+        if model_path not in ROCM_AITER_SUPPORTED_INT8_MODEL:
+            pytest.skip(
+                f"Skip model {model_path} as it is not supported by AITER.")
+        # This will also enable VLLM_ROCM_USE_AITER_LINEAR.
+        monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")
+
     with vllm_runner(model_path, dtype=torch.float16) as llm:
 
         def check_model(model):
@@ -207,6 +267,8 @@ def check_model(model):
         ("nm-testing/tinyllama-oneshot-w8a16-per-channel", "channel", None, 4),
     ],
 )
+@pytest.mark.skipif(not current_platform.is_cuda(),
+                    reason="This test is skipped on non-CUDA platforms.")
 def test_compressed_tensors_wNa16(vllm_runner, wNa16_args):
     model, strategy, group, pack_factor = wNa16_args
     with vllm_runner(model) as llm:
@@ -231,6 +293,8 @@ def check_model(model):
     assert output
 
 
+@pytest.mark.skipif(not current_platform.is_cuda(),
+                    reason="This test is skipped on non-CUDA platforms.")
 def test_compressed_tensors_w4a16_marlin24(vllm_runner):
     model_path = "nm-testing/llama7b-one-shot-2_4-w4a16-marlin24-t"
     with vllm_runner(model_path) as llm:
@@ -271,7 +335,7 @@ def check_model(model):
 
             if isinstance(qkv_proj.scheme, CompressedTensorsW8A8Fp8):
                 assert len(qkv_proj.input_scale.shape) == 0
-                assert qkv_proj.weight.dtype is torch.float8_e4m3fn
+                assert qkv_proj.weight.dtype is current_platform.fp8_dtype()
                 assert qkv_proj.weight_scale.dtype is torch.float32
                 assert len(qkv_proj.weight_scale.shape) == 0
 
@@ -281,6 +345,8 @@ def check_model(model):
     assert output
 
 
+@pytest.mark.skipif(not current_platform.is_cuda(),
+                    reason="This test is skipped on non-CUDA platforms.")
 def test_compressed_tensors_kv_cache(vllm_runner):
     model_path = "nm-testing/TinyLlama-1.1B-compressed-tensors-kv-cache-scheme"
     with vllm_runner(model_path, kv_cache_dtype="fp8") as llm:
@@ -309,7 +375,8 @@ def _test_2of4_quant_models(qkv_proj,
 
 
 @pytest.mark.skipif(
-    not current_platform.has_device_capability(90),
+    not current_platform.is_cuda()
+    or not current_platform.has_device_capability(90),
     reason="Sparse FP8 is not yet supported on this GPU type.",
 )
 @pytest.mark.parametrize(
@@ -356,7 +423,8 @@ def check_model(model):
 
 
 @pytest.mark.skipif(
-    not current_platform.has_device_capability(90),
+    not current_platform.is_cuda()
+    or not current_platform.has_device_capability(90),
     reason="Sparse FP8 is not yet supported on this GPU type.",
 )
 @pytest.mark.parametrize(