4 changes: 2 additions & 2 deletions tests/quantization/test_configs.py
@@ -54,13 +54,13 @@ class ModelPair:
     (
         "TheBloke/OpenHermes-2.5-Mistral-7B-AWQ",
         None,
-        "awq_marlin" if current_platform.is_cuda() else "awq",
+        "awq_marlin" if current_platform.is_cuda_alike() else "awq",
     ),
     ("TheBloke/OpenHermes-2.5-Mistral-7B-AWQ", "awq", "awq"),
     (
         "TheBloke/OpenHermes-2.5-Mistral-7B-AWQ",
         "marlin",
-        "awq_marlin" if current_platform.is_cuda() else "ERROR",
+        "awq_marlin" if current_platform.is_cuda_alike() else "ERROR",
     ),
     ("TheBloke/OpenHermes-2.5-Mistral-7B-AWQ", "gptq", "ERROR"),
 ]
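The switch from is_cuda() to is_cuda_alike() is what makes these expectations portable. A minimal sketch of the assumed Platform semantics (a hypothetical stand-in, not vLLM's class), where is_cuda() is true only on CUDA while is_cuda_alike() also covers ROCm:

class FakePlatform:
    def __init__(self, kind: str):
        self.kind = kind

    def is_cuda(self) -> bool:
        return self.kind == "cuda"

    def is_cuda_alike(self) -> bool:
        # assumed semantics: CUDA-alike covers CUDA and ROCm
        return self.kind in ("cuda", "rocm")

# Under this reading, ROCm now expects the marlin-backed AWQ method too:
assert FakePlatform("rocm").is_cuda_alike() and not FakePlatform("rocm").is_cuda()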
2 changes: 1 addition & 1 deletion tests/quantization/test_mixed_precision.py
@@ -43,7 +43,7 @@ def get_model_args(self) -> str:
     "amd/Qwen3-8B-WMXFP4FP8-AMXFP4FP8-AMP-KVFP8": {"arc_challenge": 0.52, "mmlu": 0.72},
     # Non-mixed-precision (PTQ) model
     # - Reference for pipeline compatibility verification -> no conflicts or breakage
-    "amd/Llama-2-70b-chat-hf-FP8-MLPerf-fp8_attn_quark_format": {
+    "amd/Llama-2-70b-chat-hf_FP8_MLPerf_V2": {
         "arc_challenge": 0.53,
         "mmlu": 0.61,
     },
5 changes: 4 additions & 1 deletion tests/quantization/test_turboquant.py
@@ -352,7 +352,10 @@ def generate_rotation_matrix(d: int, seed: int, device: str = "cpu") -> torch.Tensor:
     gen = torch.Generator(device="cpu")
     gen.manual_seed(seed)
     G = torch.randn(d, d, generator=gen, device="cpu", dtype=torch.float32)
-    Q, R = torch.linalg.qr(G)
+    # torch.linalg.qr on CPU requires LAPACK, which some torch wheels
+    # (ROCm) ship without. Run QR on an accelerator instead.
+    qr_device = "cuda" if torch.cuda.is_available() else "cpu"
+    Q, R = torch.linalg.qr(G.to(qr_device))
     diag_sign = torch.sign(torch.diag(R))
     diag_sign[diag_sign == 0] = 1.0
     Q = Q * diag_sign.unsqueeze(0)
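Wherever the QR runs, the sign fix keeps Q a proper orthogonal matrix and makes the factorization deterministic for a fixed seed. A quick sanity sketch, assuming generate_rotation_matrix returns the Q built above:

Q1 = generate_rotation_matrix(64, seed=0)
Q2 = generate_rotation_matrix(64, seed=0)
eye = torch.eye(64, device=Q1.device, dtype=Q1.dtype)
assert torch.allclose(Q1 @ Q1.T, eye, atol=1e-4)  # orthogonality
assert torch.allclose(Q1, Q2)                     # same seed, same rotation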
9 changes: 4 additions & 5 deletions tests/quantization/utils.py
@@ -80,14 +80,13 @@ def _test_online_quant_peak_mem_impl(
     print(f"GPU memory used after loading weights: {model_memory_gib} GiB")
     print(f"Peak GPU memory usage while loading weights: {peak_memory_gib} GiB")
 
-    # model specific, allenai/OLMoE-1B-7B-0125-Instruct fp8 online quant
-    # uses 6.65 GiB for weight loading (bf16 checkpoint is ~12.89 GiB)
     expected_model_memory_gib = 6.7
 
     # for allenai/OLMoE-1B-7B-0125-Instruct the number we see today is 9.06
-    # GiB, which is 1.36x above model_memory_gib. A slightly higher number is
-    # expected as when we load and quantize weights in a streaming fashion we
-    # need to have individual weights in bf16 + fp8 alive at the same time.
+    # GiB on CUDA, which is 1.36x above model_memory_gib. A slightly higher
+    # number is expected as when we load and quantize weights in a streaming
+    # fashion we need to have individual weights in bf16 + fp8 alive at the
+    # same time.
     expected_peak_memory_gib = expected_model_memory_gib * 1.4
 
     assert model_memory_gib < expected_model_memory_gib, (
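For reference, the gate above works out to 6.7 GiB * 1.4 = 9.38 GiB, so the 9.06 GiB peak observed today passes with roughly 0.3 GiB of headroom:

# worked numbers behind the threshold (taken from the comments above)
expected_model_memory_gib = 6.7
expected_peak_memory_gib = expected_model_memory_gib * 1.4  # 9.38 GiB
assert 9.06 < expected_peak_memory_gib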
10 changes: 9 additions & 1 deletion vllm/model_executor/kernels/linear/scaled_mm/pytorch.py
@@ -144,13 +144,21 @@ def apply_scaled_mm(
     # For the CUDA platform, validate that torch._scaled_mm supports
     # rowwise scaled GEMM before using it
 
+    # torch._scaled_mm rowwise requires scale_a = (m, 1), scale_b = (1, n).
+    # CompressedTensors stores weight_scale as (n, 1), so `.t()` yields (1, n).
+    # ModelOpt FP8_PER_CHANNEL_PER_TOKEN stores it as 1-D (n,); reshape to
+    # (1, n) so both paths satisfy the rowwise contract.
+    scale_b = Bs.view(1, -1) if Bs.dim() == 1 else Bs.t()
+    if As.dim() == 1:
+        As = As.view(-1, 1)
+
     # Fused GEMM_DQ Rowwise GEMM
     output = torch._scaled_mm(
         A,
         B,
         out_dtype=out_dtype,
         scale_a=As,
-        scale_b=Bs.t(),
+        scale_b=scale_b,
         bias=bias,
     )
 
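A standalone sketch of the rowwise-scale contract the new comment describes, assuming an FP8-capable GPU and a torch build where torch._scaled_mm supports rowwise scaling (mat2 column-major, scale_a of shape (m, 1), scale_b of shape (1, n)):

import torch

m, k, n = 16, 64, 32
A = torch.randn(m, k, device="cuda").to(torch.float8_e4m3fn)
B = torch.randn(n, k, device="cuda").to(torch.float8_e4m3fn).t()  # (k, n), column-major
scale_a = torch.rand(m, 1, device="cuda")   # per-row scales for A
scale_b = torch.rand(n, device="cuda")      # ModelOpt-style 1-D (n,)
scale_b = scale_b.view(1, -1)               # reshape to the required (1, n)
out = torch._scaled_mm(A, B, scale_a=scale_a, scale_b=scale_b, out_dtype=torch.bfloat16)
assert out.shape == (m, n)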
@@ -1025,9 +1025,15 @@ def __init__(
             get_current_vllm_config().model_config.hf_config, "model_type", None
         )
 
+        # TODO(aiter): extend once rocm_aiter_fused_experts gains dispatch
+        # for the other OCP MX schemes. Today its CK MoE kernel only has an
+        # entry for `w_mxfp4` (w4a16); mixed schemes like `w_mxfp4_a_mxfp6_*`
+        # fall through to QuantMethod.NO and raise "Unsupported kernel config
+        # for moe heuristic dispatch".
+        _AITER_NATIVE_OCP_MX_SCHEMES = ("w_mxfp4",)
         self.emulate = (
             not current_platform.supports_mx()
-            or not self.ocp_mx_scheme.startswith("w_mxfp4")
+            or self.ocp_mx_scheme not in _AITER_NATIVE_OCP_MX_SCHEMES
         ) and (
             self.mxfp4_backend is Mxfp4MoeBackend.NONE or not self.use_rocm_aiter_moe
         )
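The move from startswith to tuple membership is the substance of the fix: a mixed scheme such as w_mxfp4_a_mxfp6_e3m2 (a name used here only for illustration) still passes the startswith test yet has no native CK MoE entry, so it must be emulated:

# why membership beats startswith for this dispatch decision
_AITER_NATIVE_OCP_MX_SCHEMES = ("w_mxfp4",)
for scheme in ("w_mxfp4", "w_mxfp4_a_mxfp6_e3m2"):
    old_native = scheme.startswith("w_mxfp4")             # old check: both pass
    new_native = scheme in _AITER_NATIVE_OCP_MX_SCHEMES   # new: only w_mxfp4
    print(f"{scheme}: old={old_native}, new={new_native}")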
@@ -376,11 +376,15 @@ def apply_weights(
             dq_w = self.dequant_func(layer.weight, layer.weight_scale, x.dtype)
             qdq_x = self.quant_dequant_func(x)
             return F.linear(qdq_x, dq_w, bias)
-        else:
-            return torch.ops.vllm.gemm_with_dynamic_quant(
-                x,
-                layer.weight,
-                layer.weight_scale,
-                self.rocm_use_aiter_fp4_asm_gemm,
-                self.out_dtype,
-            )
+        y = torch.ops.vllm.gemm_with_dynamic_quant(
+            x,
+            layer.weight,
+            layer.weight_scale,
+            self.rocm_use_aiter_fp4_asm_gemm,
+            self.out_dtype,
+        )
+        # gemm_with_dynamic_quant has no bias argument; add it here so the
+        # native path matches F.linear (e.g. qkv_proj with qkv_bias=True).
+        if bias is not None:
+            y = y + bias
+        return y
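Adding the bias after the quantized GEMM relies only on standard broadcasting. A small self-contained check of the equivalence the comment claims, with plain tensors standing in for the quantized path:

import torch
import torch.nn.functional as F

x = torch.randn(2, 8)
w = torch.randn(16, 8)   # (out_features, in_features), as F.linear expects
b = torch.randn(16)
y = x @ w.t()            # stand-in for the bias-less GEMM output
assert torch.allclose(y + b, F.linear(x, w, b))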
2 changes: 1 addition & 1 deletion vllm/model_executor/model_loader/base_loader.py
@@ -65,7 +65,7 @@ def load_model(
 
         # Log peak GPU memory after loading weights. This is needed
         # to have test coverage on peak memory for online quantization.
-        if current_platform.is_cuda():
+        if current_platform.is_cuda_alike():
             peak_memory = torch.accelerator.max_memory_allocated()
             logger.debug_once(
                 "Peak GPU memory after loading weights: %s GiB",
13 changes: 10 additions & 3 deletions vllm/platforms/rocm.py
@@ -414,10 +414,17 @@ class RocmPlatform(Platform):
         "gguf",
         "quark",
         "mxfp4",
-        "gpt_oss_mxfp4",
+        "mxfp8",
         "torchao",
         "bitsandbytes",
         "modelopt",
+        "modelopt_fp4",
+        "modelopt_mxfp8",
+        "modelopt_mixed",
+        "fp8_per_tensor",
+        "fp8_per_block",
+        "online",
+        "gpt_oss_mxfp4",
     ]
 
     @classmethod
@@ -785,9 +792,9 @@ def get_punica_wrapper(cls) -> str:
     def get_current_memory_usage(
         cls, device: torch.types.Device | None = None
     ) -> float:
-        torch.cuda.empty_cache()
+        torch.cuda.reset_peak_memory_stats(device)
         free_mem, total_mem = torch.cuda.mem_get_info(device)
-        return total_mem - free_mem
+        return torch.cuda.max_memory_allocated(device)
 
     @classmethod
     def get_device_communicator_cls(cls) -> str:
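This changes what is measured: mem_get_info reports device-wide usage (all processes, plus cached-but-free allocator blocks), while max_memory_allocated reports the peak this process's PyTorch allocator has handed out since the last reset, which is what the peak-memory tests compare against. A sketch of the two views, assuming a single visible GPU:

import torch

device = 0
# device-wide view: includes other processes and the allocator cache
free_b, total_b = torch.cuda.mem_get_info(device)
used_device_wide_gib = (total_b - free_b) / 2**30

# process-local view: peak of tensors this process actually allocated
torch.cuda.reset_peak_memory_stats(device)
x = torch.empty(1024, 1024, device="cuda")
peak_local_gib = torch.cuda.max_memory_allocated(device) / 2**30
print(used_device_wide_gib, peak_local_gib)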