From a864404051ef36ab003402be3c07240929ab4d14 Mon Sep 17 00:00:00 2001 From: hanlin12 Date: Mon, 16 Mar 2026 06:57:58 +0000 Subject: [PATCH 01/19] Add the option to turn on hipBLASLt online tuning Signed-off-by: hanlin12 --- vllm/entrypoints/cli/serve.py | 4 ++++ vllm/entrypoints/openai/cli_args.py | 6 ++++++ 2 files changed, 10 insertions(+) diff --git a/vllm/entrypoints/cli/serve.py b/vllm/entrypoints/cli/serve.py index b0b5e7c206fc..4bb186ab7c01 100644 --- a/vllm/entrypoints/cli/serve.py +++ b/vllm/entrypoints/cli/serve.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import argparse +import os import signal import uvloop @@ -45,6 +46,9 @@ class ServeSubcommand(CLISubcommand): @staticmethod def cmd(args: argparse.Namespace) -> None: + if getattr(args, "hip_online_tuning", False): + os.environ["HIP_ONLINE_TUNING"] = "1" + # If model is specified in CLI (as positional arg), it takes precedence if hasattr(args, "model_tag") and args.model_tag is not None: args.model = args.model_tag diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index ab28b62999d8..dc8c1b94ad6b 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -348,6 +348,12 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: "Must be a YAML with the following options: " "https://docs.vllm.ai/en/latest/configuration/serve_args.html", ) + parser.add_argument( + "--hip_online_tuning", + action="store_true", + default=False, + help="Enable AITER hipBLASLt online GEMM tuning. Disabled by default.", + ) parser = FrontendArgs.add_cli_args(parser) parser = AsyncEngineArgs.add_cli_args(parser) From 96cf2d10b6f1ff95524d0551ea545bab3dfc53b0 Mon Sep 17 00:00:00 2001 From: hanlin12 Date: Tue, 17 Mar 2026 01:42:07 +0000 Subject: [PATCH 02/19] move hip_online_tuning option into serve.py Signed-off-by: hanlin12 --- vllm/entrypoints/cli/serve.py | 6 ++++++ vllm/entrypoints/openai/cli_args.py | 6 ------ 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/vllm/entrypoints/cli/serve.py b/vllm/entrypoints/cli/serve.py index 4bb186ab7c01..de5f73d207df 100644 --- a/vllm/entrypoints/cli/serve.py +++ b/vllm/entrypoints/cli/serve.py @@ -142,6 +142,12 @@ def subparser_init( help="Launch a gRPC server instead of the HTTP OpenAI-compatible " "server. Requires: pip install vllm[grpc].", ) + serve_parser.add_argument( + "--hip_online_tuning", + action="store_true", + default=False, + help="Enable AITER hipBLASLt online GEMM tuning. Disabled by default.", + ) serve_parser.epilog = VLLM_SUBCMD_PARSER_EPILOG.format(subcmd=self.name) return serve_parser diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index dc8c1b94ad6b..ab28b62999d8 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -348,12 +348,6 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: "Must be a YAML with the following options: " "https://docs.vllm.ai/en/latest/configuration/serve_args.html", ) - parser.add_argument( - "--hip_online_tuning", - action="store_true", - default=False, - help="Enable AITER hipBLASLt online GEMM tuning. Disabled by default.", - ) parser = FrontendArgs.add_cli_args(parser) parser = AsyncEngineArgs.add_cli_args(parser) From d5753b3b761a444dc57367055e2722b28292f476 Mon Sep 17 00:00:00 2001 From: hanlin12 Date: Tue, 17 Mar 2026 09:44:18 +0000 Subject: [PATCH 03/19] use environment variable instead of CLI Signed-off-by: hanlin12 --- vllm/entrypoints/cli/serve.py | 10 ---------- vllm/envs.py | 6 ++++++ vllm/platforms/rocm.py | 4 ++++ 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/vllm/entrypoints/cli/serve.py b/vllm/entrypoints/cli/serve.py index de5f73d207df..b0b5e7c206fc 100644 --- a/vllm/entrypoints/cli/serve.py +++ b/vllm/entrypoints/cli/serve.py @@ -2,7 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import argparse -import os import signal import uvloop @@ -46,9 +45,6 @@ class ServeSubcommand(CLISubcommand): @staticmethod def cmd(args: argparse.Namespace) -> None: - if getattr(args, "hip_online_tuning", False): - os.environ["HIP_ONLINE_TUNING"] = "1" - # If model is specified in CLI (as positional arg), it takes precedence if hasattr(args, "model_tag") and args.model_tag is not None: args.model = args.model_tag @@ -142,12 +138,6 @@ def subparser_init( help="Launch a gRPC server instead of the HTTP OpenAI-compatible " "server. Requires: pip install vllm[grpc].", ) - serve_parser.add_argument( - "--hip_online_tuning", - action="store_true", - default=False, - help="Enable AITER hipBLASLt online GEMM tuning. Disabled by default.", - ) serve_parser.epilog = VLLM_SUBCMD_PARSER_EPILOG.format(subcmd=self.name) return serve_parser diff --git a/vllm/envs.py b/vllm/envs.py index 3b7312a4f378..4d08b6222b93 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -244,6 +244,7 @@ VLLM_ELASTIC_EP_SCALE_UP_LAUNCH: bool = False VLLM_ELASTIC_EP_DRAIN_REQUESTS: bool = False VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS: bool = False + VLLM_ROCM_HIP_ONLINE_TUNING: bool = False def get_default_cache_root(): @@ -1026,6 +1027,11 @@ def _get_or_set_default() -> str: "VLLM_LOG_BATCHSIZE_INTERVAL": lambda: float( os.getenv("VLLM_LOG_BATCHSIZE_INTERVAL", "-1") ), + # Whether to use HIP online tuning for ROCm + # By default is disabled. + "VLLM_ROCM_HIP_ONLINE_TUNING": lambda: ( + os.getenv("VLLM_ROCM_HIP_ONLINE_TUNING", "False").lower() in ("true", "1") + ), "VLLM_DISABLE_COMPILE_CACHE": disable_compile_cache, # If set, vllm will run in development mode, which will enable # some additional endpoints for developing and debugging, diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 76be83c0638a..4df93bbde72a 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -89,6 +89,10 @@ def _sync_hip_cuda_env_vars(): # Sync at import time - catches misconfigurations from process start. _sync_hip_cuda_env_vars() +# Enable HIP online tuning early, before hipBLASLt initializes. +if envs.VLLM_ROCM_HIP_ONLINE_TUNING: + os.environ["HIP_ONLINE_TUNING"] = "1" + # AMDSMI utils # Note that NVML is not affected by `{CUDA/HIP}_VISIBLE_DEVICES`, # all the related functions work on real physical device ids. From 7abf91628f9ffcd6cb5b1f9bcb85e17cb9e1fdd6 Mon Sep 17 00:00:00 2001 From: hanlin12 Date: Tue, 17 Mar 2026 13:50:33 +0000 Subject: [PATCH 04/19] fixup suffix of environment variable Signed-off-by: hanlin12 --- vllm/envs.py | 7 ++++--- vllm/platforms/rocm.py | 2 +- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/vllm/envs.py b/vllm/envs.py index a01880c8175f..c2f50c95eb25 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -245,7 +245,7 @@ VLLM_ELASTIC_EP_SCALE_UP_LAUNCH: bool = False VLLM_ELASTIC_EP_DRAIN_REQUESTS: bool = False VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS: bool = False - VLLM_ROCM_HIP_ONLINE_TUNING: bool = False + VLLM_ROCM_USE_AITER_HIP_ONLINE_TUNING: bool = False VLLM_NIXL_EP_MAX_NUM_RANKS: int = 32 @@ -1043,8 +1043,9 @@ def _get_or_set_default() -> str: ), # Whether to use HIP online tuning for ROCm # By default is disabled. - "VLLM_ROCM_HIP_ONLINE_TUNING": lambda: ( - os.getenv("VLLM_ROCM_HIP_ONLINE_TUNING", "False").lower() in ("true", "1") + "VLLM_ROCM_USE_AITER_HIP_ONLINE_TUNING": lambda: ( + os.getenv("VLLM_ROCM_USE_AITER_HIP_ONLINE_TUNING", "False").lower() + in ("true", "1") ), "VLLM_DISABLE_COMPILE_CACHE": disable_compile_cache, # If set, vllm will run in development mode, which will enable diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index b67ac5e284c2..27d8921adbee 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -90,7 +90,7 @@ def _sync_hip_cuda_env_vars(): _sync_hip_cuda_env_vars() # Enable HIP online tuning early, before hipBLASLt initializes. -if envs.VLLM_ROCM_HIP_ONLINE_TUNING: +if envs.VLLM_ROCM_USE_AITER_HIP_ONLINE_TUNING: os.environ["HIP_ONLINE_TUNING"] = "1" # AMDSMI utils From 47744cf9a9d43640154a7661fac61b576c63589d Mon Sep 17 00:00:00 2001 From: hanlin12 Date: Tue, 7 Apr 2026 09:05:11 +0000 Subject: [PATCH 05/19] add unit test for AITER hipBLASLt online tuning --- tests/rocm/aiter/test_aiter_online_tuning.py | 483 +++++++++++++++++++ 1 file changed, 483 insertions(+) create mode 100644 tests/rocm/aiter/test_aiter_online_tuning.py diff --git a/tests/rocm/aiter/test_aiter_online_tuning.py b/tests/rocm/aiter/test_aiter_online_tuning.py new file mode 100644 index 000000000000..32230027594e --- /dev/null +++ b/tests/rocm/aiter/test_aiter_online_tuning.py @@ -0,0 +1,483 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +hipBLASLt Online Tuning Example +================================ + +This file demonstrates how to use hipBLASLt online tuning in vLLM via aiter's +`hipb_mm` kernel, and explains when/how vLLM triggers it automatically. + +Background +---------- +hipBLASLt is the AMD GEMM library used on ROCm. For a given GEMM shape +(M, N, K), there are tens to hundreds of candidate kernel algorithms. +By default, hipBLASLt uses a heuristic to pick one. Online tuning benchmarks +the candidates at runtime and caches the winner in a CSV file so subsequent +calls skip the search. + +There are two levels at which online tuning can be invoked: + +1. **C++-level (HIP_ONLINE_TUNING env var)** + Intercepted inside `hipbsolgemm.cu` for every call that goes through + `hipblasLtMatmul_sol_wrapper`. This includes `torch.nn.functional.linear` + and `torch._scaled_mm` (the PyTorch ROCm BLAS backend) as well as + aiter's `hipb_mm`. It is limited to decode-phase shapes (N <= 512). + Results are saved to `./hip_online_tuning_res.csv`. + +2. **Python-level (aiter hipb_mm with solution_index)** + Calling `hipb_mm(A, B, solution_index=-1, ...)` lets hipBLASLt choose + via heuristic. Calling it with a specific `solution_index` (found by + `hipb_findallsols` + benchmarking) uses that algorithm directly. + The gradlib GemmTuner does this offline and stores results in + `bf16_tuned_gemm.csv`. + +vLLM Integration +---------------- +Set `VLLM_ROCM_USE_AITER_HIP_ONLINE_TUNING=1` to enable C++-level online +tuning for all GEMM calls in vLLM (including `torch.nn.functional.linear`). +This env var must be set before process start; vLLM reads it at import time +and sets `HIP_ONLINE_TUNING=1` before hipBLASLt initialises. + +Usage:: + + # Enable for an entire vLLM server + VLLM_ROCM_USE_AITER_HIP_ONLINE_TUNING=1 vllm serve + + # Enable when calling vLLM from Python + VLLM_ROCM_USE_AITER_HIP_ONLINE_TUNING=1 python my_inference_script.py + +Running this file:: + + # With C++-level online tuning enabled (recommended for decode shapes): + HIP_ONLINE_TUNING=1 python test_hipblaslt_online_tuning.py + + # Or via the vLLM env var: + VLLM_ROCM_USE_AITER_HIP_ONLINE_TUNING=1 python test_hipblaslt_online_tuning.py + + # As a pytest: + HIP_ONLINE_TUNING=1 pytest tests/rocm/aiter/test_hipblaslt_online_tuning.py -v +""" + +import importlib.util +import os + +import pytest +import torch + +# --------------------------------------------------------------------------- +# Skip conditions +# --------------------------------------------------------------------------- +aiter_available = importlib.util.find_spec("aiter") is not None + +try: + from vllm.platforms import current_platform + is_rocm = current_platform.is_rocm() +except Exception: + is_rocm = False + +pytestmark = pytest.mark.skipif( + not (is_rocm and aiter_available), + reason="hipBLASLt online tuning requires ROCm + aiter", +) + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _init_hipblas(): + """Initialise the hipBLASLt handle (lazy, idempotent).""" + import aiter + aiter.hipb_create_extension() + + +def _make_inputs(m, n, k, dtype=torch.bfloat16, device="cuda"): + """Create random A [M, K] and B [N, K] tensors.""" + A = torch.randn(m, k, dtype=dtype, device=device) + B = torch.randn(n, k, dtype=dtype, device=device) + return A, B + + +def _reference(A, B): + """Compute reference result with torch for correctness check.""" + return torch.nn.functional.linear(A.float(), B.float()).to(A.dtype) + + +# --------------------------------------------------------------------------- +# Example 1: hipb_mm with heuristic (solution_index = -1) +# --------------------------------------------------------------------------- + +def test_hipb_mm_heuristic(): + """ + Demonstrates calling hipb_mm with solution_index=-1 (heuristic mode). + + When HIP_ONLINE_TUNING=1, the *first* call for a new decode shape + (N <= 512) benchmarks up to 32 candidates inside C++ and saves the + winner to ./hip_online_tuning_res.csv. Subsequent calls read the + cached algo_index from the CSV, bypassing the search entirely. + + When HIP_ONLINE_TUNING is not set (or N > 512), hipBLASLt uses its + built-in heuristic without any benchmarking. + """ + import aiter + + _init_hipblas() + + # Typical decode-phase shapes: small M (batch), large N/K + shapes = [ + (1, 4096, 4096), # batch=1 decode + (4, 4096, 4096), # batch=4 decode + (1, 8192, 8192), # batch=1, larger weights + (16, 512, 4096), # N=512, boundary case for online tuning + ] + + online_tuning_active = os.environ.get("HIP_ONLINE_TUNING", "0") in ("1", "true") + if online_tuning_active: + print("\n[INFO] HIP_ONLINE_TUNING is active — first unseen shapes will " + "be benchmarked and saved to ./hip_online_tuning_res.csv") + else: + print("\n[INFO] HIP_ONLINE_TUNING is not set — using heuristic only") + + for m, n, k in shapes: + A, B = _make_inputs(m, n, k) + + # hipb_mm expects B transposed: A [M,K] @ B.T [K,N] → C [M,N] + # solution_index=-1: let hipBLASLt decide (heuristic or online tuning) + C = aiter.hipb_mm(A, B.t(), solution_index=-1) + + ref = _reference(A, B) + assert C.shape == (m, n), f"Expected ({m},{n}), got {C.shape}" + assert torch.allclose(C.float(), ref.float(), atol=0.05, rtol=0.05), \ + f"Numerical mismatch for shape ({m},{n},{k})" + + print(f" ({m:4d}, {n:4d}, {k:4d}) ✓ out={C.shape} dtype={C.dtype}") + + print("[PASS] test_hipb_mm_heuristic") + + +# --------------------------------------------------------------------------- +# Example 2: hipb_mm with a specific solution_index (from findallsols) +# --------------------------------------------------------------------------- + +def test_hipb_mm_explicit_solution(): + """ + Demonstrates the manual workflow: + 1. hipb_findallsols() — enumerate all valid hipBLASLt algorithms. + 2. Benchmark them (simple timing loop here). + 3. Run hipb_mm with the winning solution_index. + + This is what aiter's GemmTuner does offline and stores in bf16_tuned_gemm.csv. + For production use, run the tuner once and let vLLM load the CSV at startup + via AITER_CONFIG_GEMM_BF16. + """ + import aiter + + _init_hipblas() + + m, n, k = 4, 4096, 4096 + A, B = _make_inputs(m, n, k) + + # B must be transposed when passed to hipb_mm / hipb_findallsols + B_t = B.t().contiguous() + + # Step 1: find all valid solutions for this shape + solutions = aiter.hipb_findallsols( + A, B_t, + bias=None, + out_dtype=torch.bfloat16, + scaleA=None, + scaleB=None, + bpreshuffle=False, + ) + assert len(solutions) > 0, "hipb_findallsols returned 0 solutions" + print(f"\n Found {len(solutions)} hipBLASLt solutions for " + f"({m}, {n}, {k}) bf16") + + # Step 2: quick benchmark — pick the fastest + num_warmup, num_iters = 5, 20 + best_idx = solutions[0] + best_us = float("inf") + + for sol in solutions: + # warmup + for _ in range(num_warmup): + aiter.hipb_mm(A, B_t, sol) + torch.cuda.synchronize() + + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + for _ in range(num_iters): + aiter.hipb_mm(A, B_t, sol) + end.record() + torch.cuda.synchronize() + + elapsed_us = start.elapsed_time(end) * 1000 / num_iters # ms → µs + if elapsed_us < best_us: + best_us = elapsed_us + best_idx = sol + + print(f" Best solution_index={best_idx} ({best_us:.1f} µs)") + + # Step 3: use the winning solution + C = aiter.hipb_mm(A, B_t, best_idx) + ref = _reference(A, B) + + assert C.shape == (m, n) + assert torch.allclose(C.float(), ref.float(), atol=0.05, rtol=0.05), \ + "Numerical mismatch with best solution" + + print("[PASS] test_hipb_mm_explicit_solution") + + +# --------------------------------------------------------------------------- +# Example 3: Verify the online-tuning CSV cache is populated +# --------------------------------------------------------------------------- + +def test_hip_online_tuning_csv_populated(): + """ + When HIP_ONLINE_TUNING=1, calling hipb_mm for a new decode shape + (N <= 512) should write a row to ./hip_online_tuning_res.csv. + + This test verifies the file is created and that the row for the + shape we tested is present. + + Skip automatically when HIP_ONLINE_TUNING is not set, since the + CSV will not be written in that case. + """ + if os.environ.get("HIP_ONLINE_TUNING", "0") not in ("1", "true"): + pytest.skip("HIP_ONLINE_TUNING is not set — CSV cache is not written") + + import csv + + import aiter + + _init_hipblas() + + # Use a decode shape (N <= 512) to trigger online tuning + m, n, k = 1, 256, 4096 + A, B = _make_inputs(m, n, k) + + cache_file = "./hip_online_tuning_res.csv" + + # Remove the cache entry for this shape if it exists, so we exercise + # the actual tuning path (not just the cache-hit path). + # In production you would never do this — just leave the CSV intact. + _remove_csv_row(cache_file, m, n, k) + + # First call: triggers benchmarking + writes CSV + C = aiter.hipb_mm(A, B.t(), solution_index=-1) + torch.cuda.synchronize() + + assert os.path.exists(cache_file), \ + f"Expected {cache_file} to be created by online tuning" + + # Verify a row for (m, n, k) appears in the CSV + found = _find_csv_row(cache_file, m, n, k) + assert found, \ + f"No row for ({m},{n},{k}) found in {cache_file}" + print(f"\n Cache row for ({m},{n},{k}): {found}") + + # Second call: cache-hit path (no benchmarking) + C2 = aiter.hipb_mm(A, B.t(), solution_index=-1) + torch.cuda.synchronize() + + ref = _reference(A, B) + assert torch.allclose(C.float(), ref.float(), atol=0.05, rtol=0.05) + assert torch.allclose(C2.float(), ref.float(), atol=0.05, rtol=0.05) + + print("[PASS] test_hip_online_tuning_csv_populated") + + +# --------------------------------------------------------------------------- +# Example 4: vLLM integration via VLLM_ROCM_USE_AITER_HIP_ONLINE_TUNING +# --------------------------------------------------------------------------- + +def test_vllm_env_var_sets_hip_online_tuning(): + """ + Demonstrates that vLLM's VLLM_ROCM_USE_AITER_HIP_ONLINE_TUNING=1 + propagates to HIP_ONLINE_TUNING=1 at the C++ level. + + In normal usage the env var must be set *before* the process starts + (because vllm/platforms/rocm.py reads and forwards it at import time, + before hipBLASLt is initialised). This test checks the forwarding + logic in isolation. + + To exercise the full end-to-end path, start vLLM like this:: + + VLLM_ROCM_USE_AITER_HIP_ONLINE_TUNING=1 vllm serve + + or:: + + VLLM_ROCM_USE_AITER_HIP_ONLINE_TUNING=1 python -c " + import vllm + llm = vllm.LLM('meta-llama/Llama-3.1-8B') + out = llm.generate(['Hello']) + print(out[0].outputs[0].text) + " + + The first decode requests for each unique (M, N, K) shape will trigger + online tuning (≈ a few seconds). Results persist in hip_online_tuning_res.csv, + so subsequent runs are instant. + """ + # vllm/platforms/rocm.py executes this at import time: + # if envs.VLLM_ROCM_USE_AITER_HIP_ONLINE_TUNING: + # os.environ["HIP_ONLINE_TUNING"] = "1" + # + # We simulate that forwarding here and confirm HIP_ONLINE_TUNING is set. + + import vllm.envs as envs + + if envs.VLLM_ROCM_USE_AITER_HIP_ONLINE_TUNING: + assert os.environ.get("HIP_ONLINE_TUNING") == "1", ( + "VLLM_ROCM_USE_AITER_HIP_ONLINE_TUNING=1 was set but " + "HIP_ONLINE_TUNING was not forwarded to the environment. " + "Make sure vllm.platforms.rocm is imported before hipBLASLt " + "is initialised." + ) + print("\n VLLM_ROCM_USE_AITER_HIP_ONLINE_TUNING=1 → " + "HIP_ONLINE_TUNING=1 ✓") + else: + print("\n VLLM_ROCM_USE_AITER_HIP_ONLINE_TUNING is not set " + "(HIP_ONLINE_TUNING will not be enabled via vLLM)") + + print("[PASS] test_vllm_env_var_sets_hip_online_tuning") + + +# --------------------------------------------------------------------------- +# Example 5: FP8 row-wise scaled GEMM with hipb_mm +# --------------------------------------------------------------------------- + +def test_hipb_mm_fp8_rowwise(): + """ + Demonstrates hipb_mm with FP8 inputs and row-wise scaling, + which is used for quantised inference on MI300 (gfx942) and + MI350 (gfx950). + + Row-wise scaling requires hipBLASLt >= 1.0 (ROCm 7.0+). + Online tuning also applies here when HIP_ONLINE_TUNING=1 and N <= 512. + """ + import aiter + from aiter import dtypes + + _init_hipblas() + + try: + fp8_dtype = torch.float8_e4m3fnuz # MI300 native FP8 + except AttributeError: + pytest.skip("torch.float8_e4m3fnuz not available") + + m, n, k = 4, 512, 4096 + + # Quantise inputs to FP8 + A_bf16 = torch.randn(m, k, dtype=torch.bfloat16, device="cuda") + B_bf16 = torch.randn(n, k, dtype=torch.bfloat16, device="cuda") + + A_fp8, x_scale = aiter.pertoken_quant(A_bf16, quant_dtype=fp8_dtype) + B_fp8, w_scale = aiter.pertoken_quant(B_bf16, quant_dtype=fp8_dtype) + + # x_scale shape: [M, 1], w_scale shape: [N, 1] + # hipb_mm expects scaleB to be transposed → [1, N] + C = aiter.hipb_mm( + A_fp8, + B_fp8.t(), # [K, N] + solution_index=-1, + out_dtype=torch.bfloat16, + scaleA=x_scale, + scaleB=w_scale.t(), # [1, N] + ) + + assert C.shape == (m, n), f"Expected ({m},{n}), got {C.shape}" + assert C.dtype == torch.bfloat16 + + # Reference: dequantise then matmul + ref = (A_bf16.float() @ B_bf16.float().t()).bfloat16() + assert torch.allclose(C.float(), ref.float(), atol=0.5, rtol=0.1), \ + "FP8 result deviates too far from bf16 reference" + + print(f"\n FP8 rowwise ({m},{n},{k}) ✓ out={C.shape}") + print("[PASS] test_hipb_mm_fp8_rowwise") + + +# --------------------------------------------------------------------------- +# CSV helpers (used by test_hip_online_tuning_csv_populated) +# --------------------------------------------------------------------------- + +def _find_csv_row(path: str, m: int, n: int, k: int) -> dict | None: + """Return the first CSV row whose m/n/k fields match, or None.""" + if not os.path.exists(path): + return None + import csv as _csv + with open(path, newline="") as f: + reader = _csv.DictReader(f) + for row in reader: + try: + if (int(row.get("m", -1)) == m + and int(row.get("n", -1)) == n + and int(row.get("k", -1)) == k): + return dict(row) + except (ValueError, KeyError): + continue + return None + + +def _remove_csv_row(path: str, m: int, n: int, k: int) -> None: + """Remove rows matching (m, n, k) from the CSV (for test repeatability).""" + if not os.path.exists(path): + return + import csv as _csv + rows = [] + with open(path, newline="") as f: + reader = _csv.DictReader(f) + fieldnames = reader.fieldnames + for row in reader: + try: + if not (int(row.get("m", -1)) == m + and int(row.get("n", -1)) == n + and int(row.get("k", -1)) == k): + rows.append(row) + except (ValueError, KeyError): + rows.append(row) + if fieldnames: + with open(path, "w", newline="") as f: + writer = _csv.DictWriter(f, fieldnames=fieldnames) + writer.writeheader() + writer.writerows(rows) + + +# --------------------------------------------------------------------------- +# Standalone entry-point (run without pytest) +# --------------------------------------------------------------------------- + +if __name__ == "__main__": + if not (is_rocm and aiter_available): + print("ERROR: requires ROCm platform with aiter installed") + raise SystemExit(1) + + print("=" * 60) + print("hipBLASLt Online Tuning Demo") + print("=" * 60) + print(f"HIP_ONLINE_TUNING = {os.environ.get('HIP_ONLINE_TUNING', '(not set)')}") + + try: + import vllm.envs as envs + print(f"VLLM_ROCM_USE_AITER_HIP_ONLINE_TUNING = " + f"{envs.VLLM_ROCM_USE_AITER_HIP_ONLINE_TUNING}") + except ImportError: + pass + + print() + + test_hipb_mm_heuristic() + print() + test_hipb_mm_explicit_solution() + print() + test_hip_online_tuning_csv_populated() + print() + test_vllm_env_var_sets_hip_online_tuning() + print() + test_hipb_mm_fp8_rowwise() + + print() + print("=" * 60) + print("All tests passed.") + print("=" * 60) From eccbcedef8501e9a4db15d71a7a1d4566f985d86 Mon Sep 17 00:00:00 2001 From: hanlin12 Date: Wed, 8 Apr 2026 03:04:46 +0000 Subject: [PATCH 06/19] fix typos in comment --- tests/rocm/aiter/test_aiter_online_tuning.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/tests/rocm/aiter/test_aiter_online_tuning.py b/tests/rocm/aiter/test_aiter_online_tuning.py index 32230027594e..0183cdd927ca 100644 --- a/tests/rocm/aiter/test_aiter_online_tuning.py +++ b/tests/rocm/aiter/test_aiter_online_tuning.py @@ -19,10 +19,7 @@ 1. **C++-level (HIP_ONLINE_TUNING env var)** Intercepted inside `hipbsolgemm.cu` for every call that goes through - `hipblasLtMatmul_sol_wrapper`. This includes `torch.nn.functional.linear` - and `torch._scaled_mm` (the PyTorch ROCm BLAS backend) as well as - aiter's `hipb_mm`. It is limited to decode-phase shapes (N <= 512). - Results are saved to `./hip_online_tuning_res.csv`. + `hipblasLtMatmul_sol_wrapper`. 2. **Python-level (aiter hipb_mm with solution_index)** Calling `hipb_mm(A, B, solution_index=-1, ...)` lets hipBLASLt choose From f33bfe551ea5549fe26b5367e5f8b5852eb6b5b1 Mon Sep 17 00:00:00 2001 From: hanlin12 Date: Tue, 21 Apr 2026 02:42:15 +0000 Subject: [PATCH 07/19] Add AITER hipBLASLt GEMM kernel in vLLM Signed-off-by: hanlin12 --- tests/rocm/aiter/test_aiter_online_tuning.py | 156 ++++++++++++------ vllm/_aiter_ops.py | 65 ++++++++ vllm/envs.py | 5 + .../model_executor/kernels/linear/__init__.py | 3 + .../kernels/linear/scaled_mm/aiter.py | 85 ++++++++++ vllm/platforms/rocm.py | 1 + 6 files changed, 268 insertions(+), 47 deletions(-) diff --git a/tests/rocm/aiter/test_aiter_online_tuning.py b/tests/rocm/aiter/test_aiter_online_tuning.py index 0183cdd927ca..88416d602a5b 100644 --- a/tests/rocm/aiter/test_aiter_online_tuning.py +++ b/tests/rocm/aiter/test_aiter_online_tuning.py @@ -68,6 +68,7 @@ try: from vllm.platforms import current_platform + is_rocm = current_platform.is_rocm() except Exception: is_rocm = False @@ -81,9 +82,11 @@ # Helpers # --------------------------------------------------------------------------- + def _init_hipblas(): """Initialise the hipBLASLt handle (lazy, idempotent).""" import aiter + aiter.hipb_create_extension() @@ -103,6 +106,7 @@ def _reference(A, B): # Example 1: hipb_mm with heuristic (solution_index = -1) # --------------------------------------------------------------------------- + def test_hipb_mm_heuristic(): """ Demonstrates calling hipb_mm with solution_index=-1 (heuristic mode). @@ -121,16 +125,18 @@ def test_hipb_mm_heuristic(): # Typical decode-phase shapes: small M (batch), large N/K shapes = [ - (1, 4096, 4096), # batch=1 decode - (4, 4096, 4096), # batch=4 decode - (1, 8192, 8192), # batch=1, larger weights - (16, 512, 4096), # N=512, boundary case for online tuning + (1, 4096, 4096), # batch=1 decode + (4, 4096, 4096), # batch=4 decode + (1, 8192, 8192), # batch=1, larger weights + (16, 512, 4096), # N=512, boundary case for online tuning ] online_tuning_active = os.environ.get("HIP_ONLINE_TUNING", "0") in ("1", "true") if online_tuning_active: - print("\n[INFO] HIP_ONLINE_TUNING is active — first unseen shapes will " - "be benchmarked and saved to ./hip_online_tuning_res.csv") + print( + "\n[INFO] HIP_ONLINE_TUNING is active — first unseen shapes will " + "be benchmarked and saved to ./hip_online_tuning_res.csv" + ) else: print("\n[INFO] HIP_ONLINE_TUNING is not set — using heuristic only") @@ -143,8 +149,9 @@ def test_hipb_mm_heuristic(): ref = _reference(A, B) assert C.shape == (m, n), f"Expected ({m},{n}), got {C.shape}" - assert torch.allclose(C.float(), ref.float(), atol=0.05, rtol=0.05), \ + assert torch.allclose(C.float(), ref.float(), atol=0.05, rtol=0.05), ( f"Numerical mismatch for shape ({m},{n},{k})" + ) print(f" ({m:4d}, {n:4d}, {k:4d}) ✓ out={C.shape} dtype={C.dtype}") @@ -155,6 +162,7 @@ def test_hipb_mm_heuristic(): # Example 2: hipb_mm with a specific solution_index (from findallsols) # --------------------------------------------------------------------------- + def test_hipb_mm_explicit_solution(): """ Demonstrates the manual workflow: @@ -178,7 +186,8 @@ def test_hipb_mm_explicit_solution(): # Step 1: find all valid solutions for this shape solutions = aiter.hipb_findallsols( - A, B_t, + A, + B_t, bias=None, out_dtype=torch.bfloat16, scaleA=None, @@ -186,8 +195,7 @@ def test_hipb_mm_explicit_solution(): bpreshuffle=False, ) assert len(solutions) > 0, "hipb_findallsols returned 0 solutions" - print(f"\n Found {len(solutions)} hipBLASLt solutions for " - f"({m}, {n}, {k}) bf16") + print(f"\n Found {len(solutions)} hipBLASLt solutions for ({m}, {n}, {k}) bf16") # Step 2: quick benchmark — pick the fastest num_warmup, num_iters = 5, 20 @@ -198,19 +206,19 @@ def test_hipb_mm_explicit_solution(): # warmup for _ in range(num_warmup): aiter.hipb_mm(A, B_t, sol) - torch.cuda.synchronize() + torch.accelerator.synchronize() start = torch.cuda.Event(enable_timing=True) - end = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) start.record() for _ in range(num_iters): aiter.hipb_mm(A, B_t, sol) end.record() - torch.cuda.synchronize() + torch.accelerator.synchronize() elapsed_us = start.elapsed_time(end) * 1000 / num_iters # ms → µs if elapsed_us < best_us: - best_us = elapsed_us + best_us = elapsed_us best_idx = sol print(f" Best solution_index={best_idx} ({best_us:.1f} µs)") @@ -220,8 +228,9 @@ def test_hipb_mm_explicit_solution(): ref = _reference(A, B) assert C.shape == (m, n) - assert torch.allclose(C.float(), ref.float(), atol=0.05, rtol=0.05), \ + assert torch.allclose(C.float(), ref.float(), atol=0.05, rtol=0.05), ( "Numerical mismatch with best solution" + ) print("[PASS] test_hipb_mm_explicit_solution") @@ -230,6 +239,7 @@ def test_hipb_mm_explicit_solution(): # Example 3: Verify the online-tuning CSV cache is populated # --------------------------------------------------------------------------- + def test_hip_online_tuning_csv_populated(): """ When HIP_ONLINE_TUNING=1, calling hipb_mm for a new decode shape @@ -244,8 +254,6 @@ def test_hip_online_tuning_csv_populated(): if os.environ.get("HIP_ONLINE_TUNING", "0") not in ("1", "true"): pytest.skip("HIP_ONLINE_TUNING is not set — CSV cache is not written") - import csv - import aiter _init_hipblas() @@ -263,23 +271,23 @@ def test_hip_online_tuning_csv_populated(): # First call: triggers benchmarking + writes CSV C = aiter.hipb_mm(A, B.t(), solution_index=-1) - torch.cuda.synchronize() + torch.accelerator.synchronize() - assert os.path.exists(cache_file), \ + assert os.path.exists(cache_file), ( f"Expected {cache_file} to be created by online tuning" + ) # Verify a row for (m, n, k) appears in the CSV found = _find_csv_row(cache_file, m, n, k) - assert found, \ - f"No row for ({m},{n},{k}) found in {cache_file}" + assert found, f"No row for ({m},{n},{k}) found in {cache_file}" print(f"\n Cache row for ({m},{n},{k}): {found}") # Second call: cache-hit path (no benchmarking) C2 = aiter.hipb_mm(A, B.t(), solution_index=-1) - torch.cuda.synchronize() + torch.accelerator.synchronize() ref = _reference(A, B) - assert torch.allclose(C.float(), ref.float(), atol=0.05, rtol=0.05) + assert torch.allclose(C.float(), ref.float(), atol=0.05, rtol=0.05) assert torch.allclose(C2.float(), ref.float(), atol=0.05, rtol=0.05) print("[PASS] test_hip_online_tuning_csv_populated") @@ -289,10 +297,11 @@ def test_hip_online_tuning_csv_populated(): # Example 4: vLLM integration via VLLM_ROCM_USE_AITER_HIP_ONLINE_TUNING # --------------------------------------------------------------------------- + def test_vllm_env_var_sets_hip_online_tuning(): """ - Demonstrates that vLLM's VLLM_ROCM_USE_AITER_HIP_ONLINE_TUNING=1 - propagates to HIP_ONLINE_TUNING=1 at the C++ level. + Demonstrates that vLLM's ROCm AITER knobs that rely on hipBLASLt online + tuning propagate to HIP_ONLINE_TUNING=1 at the C++ level. In normal usage the env var must be set *before* the process starts (because vllm/platforms/rocm.py reads and forwards it at import time, @@ -326,16 +335,23 @@ def test_vllm_env_var_sets_hip_online_tuning(): if envs.VLLM_ROCM_USE_AITER_HIP_ONLINE_TUNING: assert os.environ.get("HIP_ONLINE_TUNING") == "1", ( - "VLLM_ROCM_USE_AITER_HIP_ONLINE_TUNING=1 was set but " + "VLLM_ROCM_USE_AITER_HIP_ONLINE_TUNING was set but " "HIP_ONLINE_TUNING was not forwarded to the environment. " "Make sure vllm.platforms.rocm is imported before hipBLASLt " "is initialised." ) - print("\n VLLM_ROCM_USE_AITER_HIP_ONLINE_TUNING=1 → " - "HIP_ONLINE_TUNING=1 ✓") + print("\n ROCm AITER HIP tuning env var → HIP_ONLINE_TUNING=1 ✓") + elif envs.VLLM_ROCM_AITER_FORCE_HIPBMM_LINEAR: + assert os.environ.get("HIP_ONLINE_TUNING") != "1", ( + "VLLM_ROCM_AITER_FORCE_HIPBMM_LINEAR should not enable " + "HIP_ONLINE_TUNING by default." + ) + print("\n Force hipb_mm linear is set without HIP online tuning ✓") else: - print("\n VLLM_ROCM_USE_AITER_HIP_ONLINE_TUNING is not set " - "(HIP_ONLINE_TUNING will not be enabled via vLLM)") + print( + "\n No ROCm AITER HIP tuning env var is set " + "(HIP_ONLINE_TUNING will not be enabled via vLLM)" + ) print("[PASS] test_vllm_env_var_sets_hip_online_tuning") @@ -344,6 +360,7 @@ def test_vllm_env_var_sets_hip_online_tuning(): # Example 5: FP8 row-wise scaled GEMM with hipb_mm # --------------------------------------------------------------------------- + def test_hipb_mm_fp8_rowwise(): """ Demonstrates hipb_mm with FP8 inputs and row-wise scaling, @@ -354,12 +371,11 @@ def test_hipb_mm_fp8_rowwise(): Online tuning also applies here when HIP_ONLINE_TUNING=1 and N <= 512. """ import aiter - from aiter import dtypes _init_hipblas() try: - fp8_dtype = torch.float8_e4m3fnuz # MI300 native FP8 + fp8_dtype = torch.float8_e4m3fnuz # MI300 native FP8 except AttributeError: pytest.skip("torch.float8_e4m3fnuz not available") @@ -376,7 +392,7 @@ def test_hipb_mm_fp8_rowwise(): # hipb_mm expects scaleB to be transposed → [1, N] C = aiter.hipb_mm( A_fp8, - B_fp8.t(), # [K, N] + B_fp8.t(), # [K, N] solution_index=-1, out_dtype=torch.bfloat16, scaleA=x_scale, @@ -386,31 +402,71 @@ def test_hipb_mm_fp8_rowwise(): assert C.shape == (m, n), f"Expected ({m},{n}), got {C.shape}" assert C.dtype == torch.bfloat16 - # Reference: dequantise then matmul - ref = (A_bf16.float() @ B_bf16.float().t()).bfloat16() - assert torch.allclose(C.float(), ref.float(), atol=0.5, rtol=0.1), \ - "FP8 result deviates too far from bf16 reference" + # Reference: dequantise the quantized FP8 inputs using their row-wise scales, + # then run the same linear algebra in fp32. + ref = torch.nn.functional.linear( + A_fp8.float() * x_scale.float(), + B_fp8.float() * w_scale.float(), + ).bfloat16() + assert torch.allclose(C.float(), ref.float(), atol=0.5, rtol=0.1), ( + "FP8 result deviates too far from dequantized FP8 reference" + ) print(f"\n FP8 rowwise ({m},{n},{k}) ✓ out={C.shape}") print("[PASS] test_hipb_mm_fp8_rowwise") +# --------------------------------------------------------------------------- +# Example 6: vLLM kernel gating for AiterHipbMMPerTokenFp8ScaledMMLinearKernel +# --------------------------------------------------------------------------- + + +def test_aiter_hipb_mm_kernel_requires_force_flag(monkeypatch: pytest.MonkeyPatch): + from vllm._aiter_ops import rocm_aiter_ops + from vllm.model_executor.kernels.linear.scaled_mm.aiter import ( + AiterHipbMMPerTokenFp8ScaledMMLinearKernel, + ) + + monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") + monkeypatch.setenv("VLLM_ROCM_USE_AITER_LINEAR", "1") + monkeypatch.delenv("VLLM_ROCM_AITER_FORCE_HIPBMM_LINEAR", raising=False) + rocm_aiter_ops.refresh_env_variables() + + try: + is_supported, reason = AiterHipbMMPerTokenFp8ScaledMMLinearKernel.is_supported() + assert not is_supported + assert reason is not None + assert "VLLM_ROCM_AITER_FORCE_HIPBMM_LINEAR=1" in reason + print("[PASS] test_aiter_hipb_mm_kernel_requires_force_flag") + finally: + monkeypatch.undo() + rocm_aiter_ops.refresh_env_variables() + + # --------------------------------------------------------------------------- # CSV helpers (used by test_hip_online_tuning_csv_populated) # --------------------------------------------------------------------------- + +def _csv_row_matches_shape(row: dict, m: int, n: int, k: int) -> bool: + """hipBLASLt tuning CSVs may store the logical M/N order swapped.""" + row_m = int(row.get("m", -1)) + row_n = int(row.get("n", -1)) + row_k = int(row.get("k", -1)) + return row_k == k and ((row_m == m and row_n == n) or (row_m == n and row_n == m)) + + def _find_csv_row(path: str, m: int, n: int, k: int) -> dict | None: """Return the first CSV row whose m/n/k fields match, or None.""" if not os.path.exists(path): return None import csv as _csv + with open(path, newline="") as f: - reader = _csv.DictReader(f) + reader = _csv.DictReader(f, skipinitialspace=True) for row in reader: try: - if (int(row.get("m", -1)) == m - and int(row.get("n", -1)) == n - and int(row.get("k", -1)) == k): + if _csv_row_matches_shape(row, m, n, k): return dict(row) except (ValueError, KeyError): continue @@ -422,15 +478,14 @@ def _remove_csv_row(path: str, m: int, n: int, k: int) -> None: if not os.path.exists(path): return import csv as _csv + rows = [] with open(path, newline="") as f: - reader = _csv.DictReader(f) + reader = _csv.DictReader(f, skipinitialspace=True) fieldnames = reader.fieldnames for row in reader: try: - if not (int(row.get("m", -1)) == m - and int(row.get("n", -1)) == n - and int(row.get("k", -1)) == k): + if not _csv_row_matches_shape(row, m, n, k): rows.append(row) except (ValueError, KeyError): rows.append(row) @@ -457,8 +512,11 @@ def _remove_csv_row(path: str, m: int, n: int, k: int) -> None: try: import vllm.envs as envs - print(f"VLLM_ROCM_USE_AITER_HIP_ONLINE_TUNING = " - f"{envs.VLLM_ROCM_USE_AITER_HIP_ONLINE_TUNING}") + + print( + f"VLLM_ROCM_USE_AITER_HIP_ONLINE_TUNING = " + f"{envs.VLLM_ROCM_USE_AITER_HIP_ONLINE_TUNING}" + ) except ImportError: pass @@ -473,6 +531,10 @@ def _remove_csv_row(path: str, m: int, n: int, k: int) -> None: test_vllm_env_var_sets_hip_online_tuning() print() test_hipb_mm_fp8_rowwise() + print() + test_aiter_hipb_mm_kernel_requires_force_flag(pytest.MonkeyPatch()) + print() + print() print("=" * 60) diff --git a/vllm/_aiter_ops.py b/vllm/_aiter_ops.py index 55d6d1297a8c..f967d5674cdf 100644 --- a/vllm/_aiter_ops.py +++ b/vllm/_aiter_ops.py @@ -24,6 +24,16 @@ # on ROCm the fp8_dtype always calls is_fp8_fnuz # which is a host op, so we cache it once here. FP8_DTYPE = current_platform.fp8_dtype() +_HIPB_MM_INITIALIZED_DEVICES: set[int] = set() + + +def _ensure_hipb_mm_extension_initialized() -> None: + import aiter + + device = torch.accelerator.current_device_index() + if device not in _HIPB_MM_INITIALIZED_DEVICES: + aiter.hipb_create_extension() + _HIPB_MM_INITIALIZED_DEVICES.add(device) def is_aiter_found() -> bool: @@ -557,6 +567,44 @@ def _rocm_aiter_preshuffled_per_token_w8a8_gemm_fake( return torch.empty(m, n, dtype=output_dtype, device=A.device) +def _rocm_aiter_hipb_mm_fp8_impl( + A: torch.Tensor, + B: torch.Tensor, + As: torch.Tensor, + Bs: torch.Tensor, + bias: torch.Tensor | None = None, + output_dtype: torch.dtype = torch.bfloat16, +) -> torch.Tensor: + from aiter import hipb_mm + + _ensure_hipb_mm_extension_initialized() + scale_b = Bs.t().contiguous() if Bs.ndim > 1 else Bs + return hipb_mm( + A, + B.t().contiguous(), + solution_index=-1, + bias=bias, + out_dtype=output_dtype, + scaleA=As, + scaleB=scale_b, + scaleOut=None, + bpreshuffle=True, + ) + + +def _rocm_aiter_hipb_mm_fp8_fake( + A: torch.Tensor, + B: torch.Tensor, + As: torch.Tensor, + Bs: torch.Tensor, + bias: torch.Tensor | None = None, + output_dtype: torch.dtype = torch.bfloat16, +) -> torch.Tensor: + m = A.shape[0] + n = B.shape[0] + return torch.empty(m, n, dtype=output_dtype, device=A.device) + + def _rocm_aiter_triton_gemm_a8w8_blockscale_impl( A: torch.Tensor, B: torch.Tensor, @@ -1382,6 +1430,12 @@ def register_ops_once() -> None: fake_impl=_rocm_aiter_preshuffled_per_token_w8a8_gemm_fake, ) + direct_register_custom_op( + op_name="rocm_aiter_hipb_mm_fp8", + op_func=_rocm_aiter_hipb_mm_fp8_impl, + fake_impl=_rocm_aiter_hipb_mm_fp8_fake, + ) + direct_register_custom_op( op_name="rocm_aiter_triton_gemm_a8w8_blockscale", op_func=_rocm_aiter_triton_gemm_a8w8_blockscale_impl, @@ -1578,6 +1632,17 @@ def preshuffled_per_token_w8a8_gemm( A, B, As, Bs, bias, output_dtype ) + @staticmethod + def hipb_mm_fp8( + A: torch.Tensor, + B: torch.Tensor, + As: torch.Tensor, + Bs: torch.Tensor, + bias: torch.Tensor | None = None, + output_dtype: torch.dtype = torch.bfloat16, + ) -> torch.Tensor: + return torch.ops.vllm.rocm_aiter_hipb_mm_fp8(A, B, As, Bs, bias, output_dtype) + @staticmethod def triton_gemm_a8w8_blockscale( A: torch.Tensor, diff --git a/vllm/envs.py b/vllm/envs.py index f37ec338b895..eea2e8e19ac6 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -111,6 +111,7 @@ VLLM_ROCM_USE_AITER: bool = False VLLM_ROCM_USE_AITER_PAGED_ATTN: bool = False VLLM_ROCM_USE_AITER_LINEAR: bool = True + VLLM_ROCM_AITER_FORCE_HIPBMM_LINEAR: bool = False VLLM_ROCM_USE_AITER_MOE: bool = True VLLM_ROCM_USE_AITER_RMSNORM: bool = True VLLM_ROCM_USE_AITER_MLA: bool = True @@ -987,6 +988,10 @@ def _get_or_set_default() -> str: "VLLM_ROCM_USE_AITER_LINEAR": lambda: ( os.getenv("VLLM_ROCM_USE_AITER_LINEAR", "True").lower() in ("true", "1") ), + "VLLM_ROCM_AITER_FORCE_HIPBMM_LINEAR": lambda: ( + os.getenv("VLLM_ROCM_AITER_FORCE_HIPBMM_LINEAR", "False").lower() + in ("true", "1") + ), # Whether to use aiter moe ops. # By default is enabled. "VLLM_ROCM_USE_AITER_MOE": lambda: ( diff --git a/vllm/model_executor/kernels/linear/__init__.py b/vllm/model_executor/kernels/linear/__init__.py index 5d513f767f03..6dbc1eaf586b 100644 --- a/vllm/model_executor/kernels/linear/__init__.py +++ b/vllm/model_executor/kernels/linear/__init__.py @@ -105,6 +105,7 @@ ) from vllm.model_executor.kernels.linear.scaled_mm.aiter import ( AiterFp8BlockScaledMMKernel, + AiterHipbMMPerTokenFp8ScaledMMLinearKernel, AiterInt8ScaledMMLinearKernel, AiterPerTokenFp8ScaledMMLinearKernel, AiterPreshuffledPerTokenFp8ScaledMMLinearKernel, @@ -167,6 +168,7 @@ ChannelWiseTorchFP8ScaledMMLinearKernel, ], PlatformEnum.ROCM: [ + AiterHipbMMPerTokenFp8ScaledMMLinearKernel, AiterPreshuffledPerTokenFp8ScaledMMLinearKernel, AiterPerTokenFp8ScaledMMLinearKernel, ROCmFP8ScaledMMLinearKernel, @@ -743,6 +745,7 @@ def register_linear_kernel( "FP8ScaledMMLinearLayerConfig", "Int8ScaledMMLinearLayerConfig", "ScaledMMLinearLayerConfig", + "AiterHipbMMPerTokenFp8ScaledMMLinearKernel", "AiterPreshuffledPerTokenFp8ScaledMMLinearKernel", "AiterPerTokenFp8ScaledMMLinearKernel", "NvFp4LinearKernel", diff --git a/vllm/model_executor/kernels/linear/scaled_mm/aiter.py b/vllm/model_executor/kernels/linear/scaled_mm/aiter.py index 8a8650d22135..39ab7be1524a 100644 --- a/vllm/model_executor/kernels/linear/scaled_mm/aiter.py +++ b/vllm/model_executor/kernels/linear/scaled_mm/aiter.py @@ -4,6 +4,7 @@ import torch +import vllm.envs as envs from vllm import _custom_ops as ops from vllm._aiter_ops import ( rocm_aiter_ops, @@ -212,6 +213,90 @@ def apply_scaled_mm( ) +class AiterHipbMMPerTokenFp8ScaledMMLinearKernel(FP8ScaledMMLinearKernel): + @classmethod + def is_supported( + cls, compute_capability: int | None = None + ) -> tuple[bool, str | None]: + if not envs.VLLM_ROCM_AITER_FORCE_HIPBMM_LINEAR: + return False, "requires VLLM_ROCM_AITER_FORCE_HIPBMM_LINEAR=1." + + if not current_platform.is_rocm(): + return False, "requires ROCm." + if not rocm_aiter_ops.is_linear_fp8_enabled(): + return ( + False, + "requires setting `VLLM_ROCM_USE_AITER=1` " + "and `VLLM_ROCM_USE_AITER_LINEAR=1`. " + "`VLLM_ROCM_USE_AITER_LINEAR` default is True.", + ) + try: + import aiter # noqa: F401 + except Exception: + return False, "requires aiter library to be installed." + + if not hasattr(aiter, "hipb_mm"): + return False, "requires aiter hipb_mm support." + + return True, None + + @classmethod + def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | None]: + is_ptpc = ( + c.activation_quant_key.scale.group_shape.is_per_token() + and c.weight_quant_key.scale.group_shape.is_per_channel() + ) + if c.weight_shape is None: + return False, "weight_shape is required for Aiter kernels" + N, K = c.weight_shape + fp8_dtype = current_platform.fp8_dtype() + + if c.out_dtype is not torch.bfloat16: + return False, "requires bfloat16 output dtype." + + if not is_ptpc: + return ( + False, + "requires per token activation scales and per channel weight scales.", + ) + + if not (N >= 16 and N % 16 == 0 and K % 16 == 0): + return ( + False, + f"requires N >= 16 and N/K divisible by 16, received N={N} and K={K}.", + ) + + return True, None + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + w_name, *_ = self.layer_param_names + w, *_ = self._get_layer_params(layer) + + replace_parameter( + layer, + w_name, + torch.nn.Parameter( + rocm_aiter_ops.shuffle_weight(w.t().contiguous()).data, + requires_grad=False, + ), + ) + + def apply_scaled_mm( + self, + *, + A: torch.Tensor, + B: torch.Tensor, + out_dtype: torch.dtype, + As: torch.Tensor, + Bs: torch.Tensor, + bias: torch.Tensor | None, + output_shape: list, + ) -> torch.Tensor: + return rocm_aiter_ops.hipb_mm_fp8(A, B, As, Bs, bias, out_dtype).view( + *output_shape + ) + + class AiterPerTokenFp8ScaledMMLinearKernel(FP8ScaledMMLinearKernel): @classmethod def is_supported( diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 65354cd421c5..044e5d6a5ca4 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -129,6 +129,7 @@ def _sync_hip_cuda_env_vars(): _sync_hip_cuda_env_vars() # Enable HIP online tuning early, before hipBLASLt initializes. +# Forcing the hipb_mm linear kernel does not imply online tuning. if envs.VLLM_ROCM_USE_AITER_HIP_ONLINE_TUNING: os.environ["HIP_ONLINE_TUNING"] = "1" From ea9cef5cddf33a63a64d77d50ce27b882a5d47c5 Mon Sep 17 00:00:00 2001 From: Han Lin Date: Tue, 21 Apr 2026 11:28:43 +0800 Subject: [PATCH 08/19] Update vllm/model_executor/kernels/linear/scaled_mm/aiter.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Signed-off-by: Han Lin --- vllm/model_executor/kernels/linear/scaled_mm/aiter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/kernels/linear/scaled_mm/aiter.py b/vllm/model_executor/kernels/linear/scaled_mm/aiter.py index 39ab7be1524a..da107e44365d 100644 --- a/vllm/model_executor/kernels/linear/scaled_mm/aiter.py +++ b/vllm/model_executor/kernels/linear/scaled_mm/aiter.py @@ -263,7 +263,7 @@ def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | Non if not (N >= 16 and N % 16 == 0 and K % 16 == 0): return ( False, - f"requires N >= 16 and N/K divisible by 16, received N={N} and K={K}.", + f"requires N >= 16 and both N and K divisible by 16, received N={N} and K={K}.", ) return True, None From d44e5b84d5d9e86c65671e6a21d81f9db2582592 Mon Sep 17 00:00:00 2001 From: Han Lin Date: Tue, 21 Apr 2026 15:40:08 +0800 Subject: [PATCH 09/19] Update vllm/_aiter_ops.py Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Signed-off-by: Han Lin --- vllm/_aiter_ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/_aiter_ops.py b/vllm/_aiter_ops.py index c2545a8496c1..ba51d1f76bc4 100644 --- a/vllm/_aiter_ops.py +++ b/vllm/_aiter_ops.py @@ -617,7 +617,7 @@ def _rocm_aiter_hipb_mm_fp8_fake( output_dtype: torch.dtype = torch.bfloat16, ) -> torch.Tensor: m = A.shape[0] - n = B.shape[0] + n = B.shape[1] return torch.empty(m, n, dtype=output_dtype, device=A.device) From 226e0301c297491dad05ff996f174a9a2d464bd9 Mon Sep 17 00:00:00 2001 From: hanlin12 Date: Tue, 21 Apr 2026 09:33:55 +0000 Subject: [PATCH 10/19] Remove the contiguous() after preshuffle Signed-off-by: hanlin12 --- tests/rocm/aiter/test_aiter_online_tuning.py | 542 ------------------ vllm/_aiter_ops.py | 2 +- .../kernels/linear/scaled_mm/aiter.py | 1 + 3 files changed, 2 insertions(+), 543 deletions(-) delete mode 100644 tests/rocm/aiter/test_aiter_online_tuning.py diff --git a/tests/rocm/aiter/test_aiter_online_tuning.py b/tests/rocm/aiter/test_aiter_online_tuning.py deleted file mode 100644 index 88416d602a5b..000000000000 --- a/tests/rocm/aiter/test_aiter_online_tuning.py +++ /dev/null @@ -1,542 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -""" -hipBLASLt Online Tuning Example -================================ - -This file demonstrates how to use hipBLASLt online tuning in vLLM via aiter's -`hipb_mm` kernel, and explains when/how vLLM triggers it automatically. - -Background ----------- -hipBLASLt is the AMD GEMM library used on ROCm. For a given GEMM shape -(M, N, K), there are tens to hundreds of candidate kernel algorithms. -By default, hipBLASLt uses a heuristic to pick one. Online tuning benchmarks -the candidates at runtime and caches the winner in a CSV file so subsequent -calls skip the search. - -There are two levels at which online tuning can be invoked: - -1. **C++-level (HIP_ONLINE_TUNING env var)** - Intercepted inside `hipbsolgemm.cu` for every call that goes through - `hipblasLtMatmul_sol_wrapper`. - -2. **Python-level (aiter hipb_mm with solution_index)** - Calling `hipb_mm(A, B, solution_index=-1, ...)` lets hipBLASLt choose - via heuristic. Calling it with a specific `solution_index` (found by - `hipb_findallsols` + benchmarking) uses that algorithm directly. - The gradlib GemmTuner does this offline and stores results in - `bf16_tuned_gemm.csv`. - -vLLM Integration ----------------- -Set `VLLM_ROCM_USE_AITER_HIP_ONLINE_TUNING=1` to enable C++-level online -tuning for all GEMM calls in vLLM (including `torch.nn.functional.linear`). -This env var must be set before process start; vLLM reads it at import time -and sets `HIP_ONLINE_TUNING=1` before hipBLASLt initialises. - -Usage:: - - # Enable for an entire vLLM server - VLLM_ROCM_USE_AITER_HIP_ONLINE_TUNING=1 vllm serve - - # Enable when calling vLLM from Python - VLLM_ROCM_USE_AITER_HIP_ONLINE_TUNING=1 python my_inference_script.py - -Running this file:: - - # With C++-level online tuning enabled (recommended for decode shapes): - HIP_ONLINE_TUNING=1 python test_hipblaslt_online_tuning.py - - # Or via the vLLM env var: - VLLM_ROCM_USE_AITER_HIP_ONLINE_TUNING=1 python test_hipblaslt_online_tuning.py - - # As a pytest: - HIP_ONLINE_TUNING=1 pytest tests/rocm/aiter/test_hipblaslt_online_tuning.py -v -""" - -import importlib.util -import os - -import pytest -import torch - -# --------------------------------------------------------------------------- -# Skip conditions -# --------------------------------------------------------------------------- -aiter_available = importlib.util.find_spec("aiter") is not None - -try: - from vllm.platforms import current_platform - - is_rocm = current_platform.is_rocm() -except Exception: - is_rocm = False - -pytestmark = pytest.mark.skipif( - not (is_rocm and aiter_available), - reason="hipBLASLt online tuning requires ROCm + aiter", -) - -# --------------------------------------------------------------------------- -# Helpers -# --------------------------------------------------------------------------- - - -def _init_hipblas(): - """Initialise the hipBLASLt handle (lazy, idempotent).""" - import aiter - - aiter.hipb_create_extension() - - -def _make_inputs(m, n, k, dtype=torch.bfloat16, device="cuda"): - """Create random A [M, K] and B [N, K] tensors.""" - A = torch.randn(m, k, dtype=dtype, device=device) - B = torch.randn(n, k, dtype=dtype, device=device) - return A, B - - -def _reference(A, B): - """Compute reference result with torch for correctness check.""" - return torch.nn.functional.linear(A.float(), B.float()).to(A.dtype) - - -# --------------------------------------------------------------------------- -# Example 1: hipb_mm with heuristic (solution_index = -1) -# --------------------------------------------------------------------------- - - -def test_hipb_mm_heuristic(): - """ - Demonstrates calling hipb_mm with solution_index=-1 (heuristic mode). - - When HIP_ONLINE_TUNING=1, the *first* call for a new decode shape - (N <= 512) benchmarks up to 32 candidates inside C++ and saves the - winner to ./hip_online_tuning_res.csv. Subsequent calls read the - cached algo_index from the CSV, bypassing the search entirely. - - When HIP_ONLINE_TUNING is not set (or N > 512), hipBLASLt uses its - built-in heuristic without any benchmarking. - """ - import aiter - - _init_hipblas() - - # Typical decode-phase shapes: small M (batch), large N/K - shapes = [ - (1, 4096, 4096), # batch=1 decode - (4, 4096, 4096), # batch=4 decode - (1, 8192, 8192), # batch=1, larger weights - (16, 512, 4096), # N=512, boundary case for online tuning - ] - - online_tuning_active = os.environ.get("HIP_ONLINE_TUNING", "0") in ("1", "true") - if online_tuning_active: - print( - "\n[INFO] HIP_ONLINE_TUNING is active — first unseen shapes will " - "be benchmarked and saved to ./hip_online_tuning_res.csv" - ) - else: - print("\n[INFO] HIP_ONLINE_TUNING is not set — using heuristic only") - - for m, n, k in shapes: - A, B = _make_inputs(m, n, k) - - # hipb_mm expects B transposed: A [M,K] @ B.T [K,N] → C [M,N] - # solution_index=-1: let hipBLASLt decide (heuristic or online tuning) - C = aiter.hipb_mm(A, B.t(), solution_index=-1) - - ref = _reference(A, B) - assert C.shape == (m, n), f"Expected ({m},{n}), got {C.shape}" - assert torch.allclose(C.float(), ref.float(), atol=0.05, rtol=0.05), ( - f"Numerical mismatch for shape ({m},{n},{k})" - ) - - print(f" ({m:4d}, {n:4d}, {k:4d}) ✓ out={C.shape} dtype={C.dtype}") - - print("[PASS] test_hipb_mm_heuristic") - - -# --------------------------------------------------------------------------- -# Example 2: hipb_mm with a specific solution_index (from findallsols) -# --------------------------------------------------------------------------- - - -def test_hipb_mm_explicit_solution(): - """ - Demonstrates the manual workflow: - 1. hipb_findallsols() — enumerate all valid hipBLASLt algorithms. - 2. Benchmark them (simple timing loop here). - 3. Run hipb_mm with the winning solution_index. - - This is what aiter's GemmTuner does offline and stores in bf16_tuned_gemm.csv. - For production use, run the tuner once and let vLLM load the CSV at startup - via AITER_CONFIG_GEMM_BF16. - """ - import aiter - - _init_hipblas() - - m, n, k = 4, 4096, 4096 - A, B = _make_inputs(m, n, k) - - # B must be transposed when passed to hipb_mm / hipb_findallsols - B_t = B.t().contiguous() - - # Step 1: find all valid solutions for this shape - solutions = aiter.hipb_findallsols( - A, - B_t, - bias=None, - out_dtype=torch.bfloat16, - scaleA=None, - scaleB=None, - bpreshuffle=False, - ) - assert len(solutions) > 0, "hipb_findallsols returned 0 solutions" - print(f"\n Found {len(solutions)} hipBLASLt solutions for ({m}, {n}, {k}) bf16") - - # Step 2: quick benchmark — pick the fastest - num_warmup, num_iters = 5, 20 - best_idx = solutions[0] - best_us = float("inf") - - for sol in solutions: - # warmup - for _ in range(num_warmup): - aiter.hipb_mm(A, B_t, sol) - torch.accelerator.synchronize() - - start = torch.cuda.Event(enable_timing=True) - end = torch.cuda.Event(enable_timing=True) - start.record() - for _ in range(num_iters): - aiter.hipb_mm(A, B_t, sol) - end.record() - torch.accelerator.synchronize() - - elapsed_us = start.elapsed_time(end) * 1000 / num_iters # ms → µs - if elapsed_us < best_us: - best_us = elapsed_us - best_idx = sol - - print(f" Best solution_index={best_idx} ({best_us:.1f} µs)") - - # Step 3: use the winning solution - C = aiter.hipb_mm(A, B_t, best_idx) - ref = _reference(A, B) - - assert C.shape == (m, n) - assert torch.allclose(C.float(), ref.float(), atol=0.05, rtol=0.05), ( - "Numerical mismatch with best solution" - ) - - print("[PASS] test_hipb_mm_explicit_solution") - - -# --------------------------------------------------------------------------- -# Example 3: Verify the online-tuning CSV cache is populated -# --------------------------------------------------------------------------- - - -def test_hip_online_tuning_csv_populated(): - """ - When HIP_ONLINE_TUNING=1, calling hipb_mm for a new decode shape - (N <= 512) should write a row to ./hip_online_tuning_res.csv. - - This test verifies the file is created and that the row for the - shape we tested is present. - - Skip automatically when HIP_ONLINE_TUNING is not set, since the - CSV will not be written in that case. - """ - if os.environ.get("HIP_ONLINE_TUNING", "0") not in ("1", "true"): - pytest.skip("HIP_ONLINE_TUNING is not set — CSV cache is not written") - - import aiter - - _init_hipblas() - - # Use a decode shape (N <= 512) to trigger online tuning - m, n, k = 1, 256, 4096 - A, B = _make_inputs(m, n, k) - - cache_file = "./hip_online_tuning_res.csv" - - # Remove the cache entry for this shape if it exists, so we exercise - # the actual tuning path (not just the cache-hit path). - # In production you would never do this — just leave the CSV intact. - _remove_csv_row(cache_file, m, n, k) - - # First call: triggers benchmarking + writes CSV - C = aiter.hipb_mm(A, B.t(), solution_index=-1) - torch.accelerator.synchronize() - - assert os.path.exists(cache_file), ( - f"Expected {cache_file} to be created by online tuning" - ) - - # Verify a row for (m, n, k) appears in the CSV - found = _find_csv_row(cache_file, m, n, k) - assert found, f"No row for ({m},{n},{k}) found in {cache_file}" - print(f"\n Cache row for ({m},{n},{k}): {found}") - - # Second call: cache-hit path (no benchmarking) - C2 = aiter.hipb_mm(A, B.t(), solution_index=-1) - torch.accelerator.synchronize() - - ref = _reference(A, B) - assert torch.allclose(C.float(), ref.float(), atol=0.05, rtol=0.05) - assert torch.allclose(C2.float(), ref.float(), atol=0.05, rtol=0.05) - - print("[PASS] test_hip_online_tuning_csv_populated") - - -# --------------------------------------------------------------------------- -# Example 4: vLLM integration via VLLM_ROCM_USE_AITER_HIP_ONLINE_TUNING -# --------------------------------------------------------------------------- - - -def test_vllm_env_var_sets_hip_online_tuning(): - """ - Demonstrates that vLLM's ROCm AITER knobs that rely on hipBLASLt online - tuning propagate to HIP_ONLINE_TUNING=1 at the C++ level. - - In normal usage the env var must be set *before* the process starts - (because vllm/platforms/rocm.py reads and forwards it at import time, - before hipBLASLt is initialised). This test checks the forwarding - logic in isolation. - - To exercise the full end-to-end path, start vLLM like this:: - - VLLM_ROCM_USE_AITER_HIP_ONLINE_TUNING=1 vllm serve - - or:: - - VLLM_ROCM_USE_AITER_HIP_ONLINE_TUNING=1 python -c " - import vllm - llm = vllm.LLM('meta-llama/Llama-3.1-8B') - out = llm.generate(['Hello']) - print(out[0].outputs[0].text) - " - - The first decode requests for each unique (M, N, K) shape will trigger - online tuning (≈ a few seconds). Results persist in hip_online_tuning_res.csv, - so subsequent runs are instant. - """ - # vllm/platforms/rocm.py executes this at import time: - # if envs.VLLM_ROCM_USE_AITER_HIP_ONLINE_TUNING: - # os.environ["HIP_ONLINE_TUNING"] = "1" - # - # We simulate that forwarding here and confirm HIP_ONLINE_TUNING is set. - - import vllm.envs as envs - - if envs.VLLM_ROCM_USE_AITER_HIP_ONLINE_TUNING: - assert os.environ.get("HIP_ONLINE_TUNING") == "1", ( - "VLLM_ROCM_USE_AITER_HIP_ONLINE_TUNING was set but " - "HIP_ONLINE_TUNING was not forwarded to the environment. " - "Make sure vllm.platforms.rocm is imported before hipBLASLt " - "is initialised." - ) - print("\n ROCm AITER HIP tuning env var → HIP_ONLINE_TUNING=1 ✓") - elif envs.VLLM_ROCM_AITER_FORCE_HIPBMM_LINEAR: - assert os.environ.get("HIP_ONLINE_TUNING") != "1", ( - "VLLM_ROCM_AITER_FORCE_HIPBMM_LINEAR should not enable " - "HIP_ONLINE_TUNING by default." - ) - print("\n Force hipb_mm linear is set without HIP online tuning ✓") - else: - print( - "\n No ROCm AITER HIP tuning env var is set " - "(HIP_ONLINE_TUNING will not be enabled via vLLM)" - ) - - print("[PASS] test_vllm_env_var_sets_hip_online_tuning") - - -# --------------------------------------------------------------------------- -# Example 5: FP8 row-wise scaled GEMM with hipb_mm -# --------------------------------------------------------------------------- - - -def test_hipb_mm_fp8_rowwise(): - """ - Demonstrates hipb_mm with FP8 inputs and row-wise scaling, - which is used for quantised inference on MI300 (gfx942) and - MI350 (gfx950). - - Row-wise scaling requires hipBLASLt >= 1.0 (ROCm 7.0+). - Online tuning also applies here when HIP_ONLINE_TUNING=1 and N <= 512. - """ - import aiter - - _init_hipblas() - - try: - fp8_dtype = torch.float8_e4m3fnuz # MI300 native FP8 - except AttributeError: - pytest.skip("torch.float8_e4m3fnuz not available") - - m, n, k = 4, 512, 4096 - - # Quantise inputs to FP8 - A_bf16 = torch.randn(m, k, dtype=torch.bfloat16, device="cuda") - B_bf16 = torch.randn(n, k, dtype=torch.bfloat16, device="cuda") - - A_fp8, x_scale = aiter.pertoken_quant(A_bf16, quant_dtype=fp8_dtype) - B_fp8, w_scale = aiter.pertoken_quant(B_bf16, quant_dtype=fp8_dtype) - - # x_scale shape: [M, 1], w_scale shape: [N, 1] - # hipb_mm expects scaleB to be transposed → [1, N] - C = aiter.hipb_mm( - A_fp8, - B_fp8.t(), # [K, N] - solution_index=-1, - out_dtype=torch.bfloat16, - scaleA=x_scale, - scaleB=w_scale.t(), # [1, N] - ) - - assert C.shape == (m, n), f"Expected ({m},{n}), got {C.shape}" - assert C.dtype == torch.bfloat16 - - # Reference: dequantise the quantized FP8 inputs using their row-wise scales, - # then run the same linear algebra in fp32. - ref = torch.nn.functional.linear( - A_fp8.float() * x_scale.float(), - B_fp8.float() * w_scale.float(), - ).bfloat16() - assert torch.allclose(C.float(), ref.float(), atol=0.5, rtol=0.1), ( - "FP8 result deviates too far from dequantized FP8 reference" - ) - - print(f"\n FP8 rowwise ({m},{n},{k}) ✓ out={C.shape}") - print("[PASS] test_hipb_mm_fp8_rowwise") - - -# --------------------------------------------------------------------------- -# Example 6: vLLM kernel gating for AiterHipbMMPerTokenFp8ScaledMMLinearKernel -# --------------------------------------------------------------------------- - - -def test_aiter_hipb_mm_kernel_requires_force_flag(monkeypatch: pytest.MonkeyPatch): - from vllm._aiter_ops import rocm_aiter_ops - from vllm.model_executor.kernels.linear.scaled_mm.aiter import ( - AiterHipbMMPerTokenFp8ScaledMMLinearKernel, - ) - - monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") - monkeypatch.setenv("VLLM_ROCM_USE_AITER_LINEAR", "1") - monkeypatch.delenv("VLLM_ROCM_AITER_FORCE_HIPBMM_LINEAR", raising=False) - rocm_aiter_ops.refresh_env_variables() - - try: - is_supported, reason = AiterHipbMMPerTokenFp8ScaledMMLinearKernel.is_supported() - assert not is_supported - assert reason is not None - assert "VLLM_ROCM_AITER_FORCE_HIPBMM_LINEAR=1" in reason - print("[PASS] test_aiter_hipb_mm_kernel_requires_force_flag") - finally: - monkeypatch.undo() - rocm_aiter_ops.refresh_env_variables() - - -# --------------------------------------------------------------------------- -# CSV helpers (used by test_hip_online_tuning_csv_populated) -# --------------------------------------------------------------------------- - - -def _csv_row_matches_shape(row: dict, m: int, n: int, k: int) -> bool: - """hipBLASLt tuning CSVs may store the logical M/N order swapped.""" - row_m = int(row.get("m", -1)) - row_n = int(row.get("n", -1)) - row_k = int(row.get("k", -1)) - return row_k == k and ((row_m == m and row_n == n) or (row_m == n and row_n == m)) - - -def _find_csv_row(path: str, m: int, n: int, k: int) -> dict | None: - """Return the first CSV row whose m/n/k fields match, or None.""" - if not os.path.exists(path): - return None - import csv as _csv - - with open(path, newline="") as f: - reader = _csv.DictReader(f, skipinitialspace=True) - for row in reader: - try: - if _csv_row_matches_shape(row, m, n, k): - return dict(row) - except (ValueError, KeyError): - continue - return None - - -def _remove_csv_row(path: str, m: int, n: int, k: int) -> None: - """Remove rows matching (m, n, k) from the CSV (for test repeatability).""" - if not os.path.exists(path): - return - import csv as _csv - - rows = [] - with open(path, newline="") as f: - reader = _csv.DictReader(f, skipinitialspace=True) - fieldnames = reader.fieldnames - for row in reader: - try: - if not _csv_row_matches_shape(row, m, n, k): - rows.append(row) - except (ValueError, KeyError): - rows.append(row) - if fieldnames: - with open(path, "w", newline="") as f: - writer = _csv.DictWriter(f, fieldnames=fieldnames) - writer.writeheader() - writer.writerows(rows) - - -# --------------------------------------------------------------------------- -# Standalone entry-point (run without pytest) -# --------------------------------------------------------------------------- - -if __name__ == "__main__": - if not (is_rocm and aiter_available): - print("ERROR: requires ROCm platform with aiter installed") - raise SystemExit(1) - - print("=" * 60) - print("hipBLASLt Online Tuning Demo") - print("=" * 60) - print(f"HIP_ONLINE_TUNING = {os.environ.get('HIP_ONLINE_TUNING', '(not set)')}") - - try: - import vllm.envs as envs - - print( - f"VLLM_ROCM_USE_AITER_HIP_ONLINE_TUNING = " - f"{envs.VLLM_ROCM_USE_AITER_HIP_ONLINE_TUNING}" - ) - except ImportError: - pass - - print() - - test_hipb_mm_heuristic() - print() - test_hipb_mm_explicit_solution() - print() - test_hip_online_tuning_csv_populated() - print() - test_vllm_env_var_sets_hip_online_tuning() - print() - test_hipb_mm_fp8_rowwise() - print() - test_aiter_hipb_mm_kernel_requires_force_flag(pytest.MonkeyPatch()) - print() - - - print() - print("=" * 60) - print("All tests passed.") - print("=" * 60) diff --git a/vllm/_aiter_ops.py b/vllm/_aiter_ops.py index ba51d1f76bc4..b91d9c28e352 100644 --- a/vllm/_aiter_ops.py +++ b/vllm/_aiter_ops.py @@ -597,7 +597,7 @@ def _rocm_aiter_hipb_mm_fp8_impl( scale_b = Bs.t().contiguous() if Bs.ndim > 1 else Bs return hipb_mm( A, - B.t().contiguous(), + B.t(), solution_index=-1, bias=bias, out_dtype=output_dtype, diff --git a/vllm/model_executor/kernels/linear/scaled_mm/aiter.py b/vllm/model_executor/kernels/linear/scaled_mm/aiter.py index da107e44365d..b3c26e0a8795 100644 --- a/vllm/model_executor/kernels/linear/scaled_mm/aiter.py +++ b/vllm/model_executor/kernels/linear/scaled_mm/aiter.py @@ -292,6 +292,7 @@ def apply_scaled_mm( bias: torch.Tensor | None, output_shape: list, ) -> torch.Tensor: + output_shape[-1] = B.shape[0] return rocm_aiter_ops.hipb_mm_fp8(A, B, As, Bs, bias, out_dtype).view( *output_shape ) From 50a8a1db805bf7bd9496b2710806cdbf7c145c61 Mon Sep 17 00:00:00 2001 From: hanlin12 Date: Tue, 28 Apr 2026 08:00:34 +0000 Subject: [PATCH 11/19] fix the env name and logic of aiter hipblaslt gemm and online tuning Signed-off-by: hanlin12 --- vllm/_aiter_ops.py | 9 +++++++++ vllm/envs.py | 13 +++---------- .../kernels/linear/scaled_mm/aiter.py | 12 ++++++------ vllm/platforms/rocm.py | 4 ++-- 4 files changed, 20 insertions(+), 18 deletions(-) diff --git a/vllm/_aiter_ops.py b/vllm/_aiter_ops.py index b91d9c28e352..0ba024740d38 100644 --- a/vllm/_aiter_ops.py +++ b/vllm/_aiter_ops.py @@ -1229,6 +1229,7 @@ class rocm_aiter_ops: _FP8BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_FP8BMM _FP4BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_FP4BMM # TODO: Consolidate under _LINEAR_ENABLED + _HIP_FP8BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_LINEAR_HIPBMM _FP4_GEMM_DYNAMIC_QUANT_ASM = envs.VLLM_ROCM_USE_AITER_FP4_ASM_GEMM # TODO: Consolidate under VLLM_ROCM_USE_AITER_ROPE _TRITON_ROTARY_EMBED = envs.VLLM_ROCM_USE_AITER_TRITON_ROPE @@ -1255,6 +1256,7 @@ def refresh_env_variables(cls): cls._TRITON_UNIFIED_ATTN_ENABLED = envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION cls._FP8BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_FP8BMM cls._FP4BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_FP4BMM + cls._HIP_FP8BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_LINEAR_HIPBMM cls._FP4_GEMM_DYNAMIC_QUANT_ASM = envs.VLLM_ROCM_USE_AITER_FP4_ASM_GEMM cls._TRITON_ROTARY_EMBED = envs.VLLM_ROCM_USE_AITER_TRITON_ROPE cls._MOE_SHARED_EXPERTS_ENABLED = envs.VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS @@ -1386,6 +1388,13 @@ def is_fp4bmm_enabled(cls) -> bool: return cls._AITER_ENABLED and cls._FP4BMM_ENABLED and on_gfx950() + @classmethod + @if_aiter_supported + def is_hip_fp8bmm_enabled(cls) -> bool: + from vllm.platforms.rocm import on_mi3xx + + return cls._AITER_ENABLED and cls._HIP_FP8BMM_ENABLED and on_mi3xx() + @classmethod @if_aiter_supported def is_asm_fp4_gemm_dynamic_quant_enabled(cls) -> bool: diff --git a/vllm/envs.py b/vllm/envs.py index d778b3ac31e0..b5671615b784 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -111,7 +111,7 @@ VLLM_ROCM_USE_AITER: bool = False VLLM_ROCM_USE_AITER_PAGED_ATTN: bool = False VLLM_ROCM_USE_AITER_LINEAR: bool = True - VLLM_ROCM_AITER_FORCE_HIPBMM_LINEAR: bool = False + VLLM_ROCM_USE_AITER_LINEAR_HIPBMM: bool = False VLLM_ROCM_USE_AITER_MOE: bool = True VLLM_ROCM_USE_AITER_RMSNORM: bool = True VLLM_ROCM_USE_AITER_MLA: bool = True @@ -255,7 +255,6 @@ VLLM_ELASTIC_EP_SCALE_UP_LAUNCH: bool = False VLLM_ELASTIC_EP_DRAIN_REQUESTS: bool = False VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS: bool = False - VLLM_ROCM_USE_AITER_HIP_ONLINE_TUNING: bool = False VLLM_NIXL_EP_MAX_NUM_RANKS: int = 32 VLLM_XPU_ENABLE_XPU_GRAPH: bool = False VLLM_LORA_ENABLE_DUAL_STREAM: bool = False @@ -992,8 +991,8 @@ def _get_or_set_default() -> str: "VLLM_ROCM_USE_AITER_LINEAR": lambda: ( os.getenv("VLLM_ROCM_USE_AITER_LINEAR", "True").lower() in ("true", "1") ), - "VLLM_ROCM_AITER_FORCE_HIPBMM_LINEAR": lambda: ( - os.getenv("VLLM_ROCM_AITER_FORCE_HIPBMM_LINEAR", "False").lower() + "VLLM_ROCM_USE_AITER_LINEAR_HIPBMM": lambda: ( + os.getenv("VLLM_ROCM_USE_AITER_LINEAR_HIPBMM", "False").lower() in ("true", "1") ), # Whether to use aiter moe ops. @@ -1101,12 +1100,6 @@ def _get_or_set_default() -> str: "VLLM_LOG_BATCHSIZE_INTERVAL": lambda: float( os.getenv("VLLM_LOG_BATCHSIZE_INTERVAL", "-1") ), - # Whether to use HIP online tuning for ROCm - # By default is disabled. - "VLLM_ROCM_USE_AITER_HIP_ONLINE_TUNING": lambda: ( - os.getenv("VLLM_ROCM_USE_AITER_HIP_ONLINE_TUNING", "False").lower() - in ("true", "1") - ), "VLLM_DISABLE_COMPILE_CACHE": disable_compile_cache, # If set to "0", disable LayerName opaque type for layer_name # parameters in custom ops. Defaults to enabled on torch >= 2.11. diff --git a/vllm/model_executor/kernels/linear/scaled_mm/aiter.py b/vllm/model_executor/kernels/linear/scaled_mm/aiter.py index b3c26e0a8795..9b298aa6a8d5 100644 --- a/vllm/model_executor/kernels/linear/scaled_mm/aiter.py +++ b/vllm/model_executor/kernels/linear/scaled_mm/aiter.py @@ -218,17 +218,17 @@ class AiterHipbMMPerTokenFp8ScaledMMLinearKernel(FP8ScaledMMLinearKernel): def is_supported( cls, compute_capability: int | None = None ) -> tuple[bool, str | None]: - if not envs.VLLM_ROCM_AITER_FORCE_HIPBMM_LINEAR: - return False, "requires VLLM_ROCM_AITER_FORCE_HIPBMM_LINEAR=1." - if not current_platform.is_rocm(): return False, "requires ROCm." - if not rocm_aiter_ops.is_linear_fp8_enabled(): + + if not rocm_aiter_ops.is_hip_fp8bmm_enabled(): + return False, "requires VLLM_ROCM_USE_AITER_LINEAR_HIPBMM =1." + + if not rocm_aiter_ops.is_hip_fp8bmm_enabled(): return ( False, "requires setting `VLLM_ROCM_USE_AITER=1` " - "and `VLLM_ROCM_USE_AITER_LINEAR=1`. " - "`VLLM_ROCM_USE_AITER_LINEAR` default is True.", + "and `VLLM_ROCM_USE_AITER_LINEAR_HIPBMM=1`. ", ) try: import aiter # noqa: F401 diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 044e5d6a5ca4..5fb950ebefe9 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -129,8 +129,8 @@ def _sync_hip_cuda_env_vars(): _sync_hip_cuda_env_vars() # Enable HIP online tuning early, before hipBLASLt initializes. -# Forcing the hipb_mm linear kernel does not imply online tuning. -if envs.VLLM_ROCM_USE_AITER_HIP_ONLINE_TUNING: +# Turn on hipBLASLt online tuning if use AITER hipBLASLt GEMM. +if envs.VLLM_ROCM_USE_AITER_LINEAR_HIPBMM: os.environ["HIP_ONLINE_TUNING"] = "1" # AMDSMI utils From c4ff925510adb1d87be63eab94819608806911a9 Mon Sep 17 00:00:00 2001 From: hanlin12 Date: Tue, 28 Apr 2026 17:18:48 +0000 Subject: [PATCH 12/19] ensure VLLM_ROCM_USE_AITER_LINEAR_HIPBMM working Signed-off-by: hanlin12 --- vllm/_aiter_ops.py | 2 +- vllm/platforms/rocm.py | 7 +++++++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/vllm/_aiter_ops.py b/vllm/_aiter_ops.py index 0ba024740d38..479aa80e07af 100644 --- a/vllm/_aiter_ops.py +++ b/vllm/_aiter_ops.py @@ -617,7 +617,7 @@ def _rocm_aiter_hipb_mm_fp8_fake( output_dtype: torch.dtype = torch.bfloat16, ) -> torch.Tensor: m = A.shape[0] - n = B.shape[1] + n = B.shape[0] return torch.empty(m, n, dtype=output_dtype, device=A.device) diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 5fb950ebefe9..1b675b6b59db 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -132,6 +132,13 @@ def _sync_hip_cuda_env_vars(): # Turn on hipBLASLt online tuning if use AITER hipBLASLt GEMM. if envs.VLLM_ROCM_USE_AITER_LINEAR_HIPBMM: os.environ["HIP_ONLINE_TUNING"] = "1" + # hipBMM requires aiter enabled, and disabling non-hipBMM linear + # avoids the +quant_fp8 custom op that triggers fuse_norm_quant, + # which fails with Float tensors. Rmsnorm requires composable_kernel + # submodule which may not be initialized, so disable it too. + os.environ["VLLM_ROCM_USE_AITER"] = "1" + os.environ["VLLM_ROCM_USE_AITER_LINEAR"] = "0" + os.environ["VLLM_ROCM_USE_AITER_RMSNORM"] = "0" # AMDSMI utils # Note that NVML is not affected by `{CUDA/HIP}_VISIBLE_DEVICES`, From 1ebc71e6a3f5714f887d927df93825b5314a16e4 Mon Sep 17 00:00:00 2001 From: hanlin12 Date: Fri, 8 May 2026 07:59:38 +0000 Subject: [PATCH 13/19] Change the conditions of hipblaslt online tuning Signed-off-by: hanlin12 --- vllm/_aiter_ops.py | 2 +- .../kernels/linear/scaled_mm/aiter.py | 7 ++++- vllm/platforms/rocm.py | 26 ++++++++++--------- 3 files changed, 21 insertions(+), 14 deletions(-) diff --git a/vllm/_aiter_ops.py b/vllm/_aiter_ops.py index 7db328d65c46..fbb092459879 100644 --- a/vllm/_aiter_ops.py +++ b/vllm/_aiter_ops.py @@ -1409,7 +1409,7 @@ def is_fp4bmm_enabled(cls) -> bool: def is_hip_fp8bmm_enabled(cls) -> bool: from vllm.platforms.rocm import on_mi3xx - return cls._AITER_ENABLED and cls._HIP_FP8BMM_ENABLED and on_mi3xx() + return cls.is_linear_enabled() and on_mi3xx() @classmethod @if_aiter_supported diff --git a/vllm/model_executor/kernels/linear/scaled_mm/aiter.py b/vllm/model_executor/kernels/linear/scaled_mm/aiter.py index 188ae3f1853b..f39abf50f838 100644 --- a/vllm/model_executor/kernels/linear/scaled_mm/aiter.py +++ b/vllm/model_executor/kernels/linear/scaled_mm/aiter.py @@ -222,7 +222,12 @@ def is_supported( return False, "requires ROCm." if not rocm_aiter_ops.is_hip_fp8bmm_enabled(): - return False, "requires VLLM_ROCM_USE_AITER_LINEAR_HIPBMM =1." + return ( + False, + "requires setting `VLLM_ROCM_USE_AITER=1` " + "and `VLLM_ROCM_USE_AITER_LINEAR=1`. ", + "and `VLLM_ROCM_USE_AITER_LINEAR_HIPBMM=1`. ", + ) if not rocm_aiter_ops.is_hip_fp8bmm_enabled(): return ( diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 879162f947fb..8ece4ebdc76e 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -129,17 +129,7 @@ def _sync_hip_cuda_env_vars(): # Sync at import time - catches misconfigurations from process start. _sync_hip_cuda_env_vars() -# Enable HIP online tuning early, before hipBLASLt initializes. -# Turn on hipBLASLt online tuning if use AITER hipBLASLt GEMM. -if envs.VLLM_ROCM_USE_AITER_LINEAR_HIPBMM: - os.environ["HIP_ONLINE_TUNING"] = "1" - # hipBMM requires aiter enabled, and disabling non-hipBMM linear - # avoids the +quant_fp8 custom op that triggers fuse_norm_quant, - # which fails with Float tensors. Rmsnorm requires composable_kernel - # submodule which may not be initialized, so disable it too. - os.environ["VLLM_ROCM_USE_AITER"] = "1" - os.environ["VLLM_ROCM_USE_AITER_LINEAR"] = "0" - os.environ["VLLM_ROCM_USE_AITER_RMSNORM"] = "0" + # AMDSMI utils # Note that NVML is not affected by `{CUDA/HIP}_VISIBLE_DEVICES`, @@ -303,7 +293,19 @@ def on_gfx942() -> bool: def on_gfx950() -> bool: return _ON_GFX950 - +# Enable HIP online tuning early, before hipBLASLt initializes. +# Turn on hipBLASLt online tuning if use AITER hipBLASLt GEMM. +if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_LINEAR and envs.VLLM_ROCM_USE_AITER_LINEAR_HIPBMM and on_mi3xx(): + os.environ["HIP_ONLINE_TUNING"] = "1" + # hipBMM requires aiter enabled, and disabling non-hipBMM linear + # avoids the +quant_fp8 custom op that triggers fuse_norm_quant, + # which fails with Float tensors. Rmsnorm requires composable_kernel + # submodule which may not be initialized, so disable it too. + + #os.environ["VLLM_ROCM_USE_AITER"] = "1" + #os.environ["VLLM_ROCM_USE_AITER_LINEAR"] = "0" + #os.environ["VLLM_ROCM_USE_AITER_RMSNORM"] = "0" + @cache def use_rocm_custom_paged_attention( qtype: torch.dtype, From fde58786b2620bfc2fdb5248440711a34180e52e Mon Sep 17 00:00:00 2001 From: hanlin12 Date: Sat, 9 May 2026 07:49:25 +0000 Subject: [PATCH 14/19] Resolve the condition of hipBLASLt online tuning Signed-off-by: hanlin12 --- vllm/_aiter_ops.py | 7 ++-- .../kernels/linear/scaled_mm/aiter.py | 33 ++++++++++--------- vllm/platforms/rocm.py | 8 ----- 3 files changed, 21 insertions(+), 27 deletions(-) diff --git a/vllm/_aiter_ops.py b/vllm/_aiter_ops.py index fbb092459879..21466110c3c9 100644 --- a/vllm/_aiter_ops.py +++ b/vllm/_aiter_ops.py @@ -618,15 +618,14 @@ def _rocm_aiter_hipb_mm_fp8_impl( from aiter import hipb_mm _ensure_hipb_mm_extension_initialized() - scale_b = Bs.t().contiguous() if Bs.ndim > 1 else Bs return hipb_mm( A, - B.t(), + B, solution_index=-1, bias=bias, out_dtype=output_dtype, scaleA=As, - scaleB=scale_b, + scaleB=Bs, scaleOut=None, bpreshuffle=True, ) @@ -641,7 +640,7 @@ def _rocm_aiter_hipb_mm_fp8_fake( output_dtype: torch.dtype = torch.bfloat16, ) -> torch.Tensor: m = A.shape[0] - n = B.shape[0] + n = B.shape[1] return torch.empty(m, n, dtype=output_dtype, device=A.device) diff --git a/vllm/model_executor/kernels/linear/scaled_mm/aiter.py b/vllm/model_executor/kernels/linear/scaled_mm/aiter.py index f39abf50f838..a78cf96ce1f3 100644 --- a/vllm/model_executor/kernels/linear/scaled_mm/aiter.py +++ b/vllm/model_executor/kernels/linear/scaled_mm/aiter.py @@ -228,13 +228,6 @@ def is_supported( "and `VLLM_ROCM_USE_AITER_LINEAR=1`. ", "and `VLLM_ROCM_USE_AITER_LINEAR_HIPBMM=1`. ", ) - - if not rocm_aiter_ops.is_hip_fp8bmm_enabled(): - return ( - False, - "requires setting `VLLM_ROCM_USE_AITER=1` " - "and `VLLM_ROCM_USE_AITER_LINEAR_HIPBMM=1`. ", - ) try: import aiter # noqa: F401 except Exception: @@ -274,18 +267,28 @@ def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | Non return True, None def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - w_name, *_ = self.layer_param_names - w, *_ = self._get_layer_params(layer) - + w_name, w_s_name, *_ = self.layer_param_names + w, w_s, *_ = self._get_layer_params(layer) + + # Pre-apply the transposes that used to live in + # _rocm_aiter_hipb_mm_fp8_impl so the kernel can consume B/Bs directly. + # The `.t()` on the shuffled weight is kept as a non-contiguous view — + # materializing it with `.contiguous()` would re-arrange the bytes and + # break the `bpreshuffle` layout. + shuffled_w = rocm_aiter_ops.shuffle_weight(w.t().contiguous()) replace_parameter( layer, w_name, - torch.nn.Parameter( - rocm_aiter_ops.shuffle_weight(w.t().contiguous()).data, - requires_grad=False, - ), + torch.nn.Parameter(shuffled_w.t(), requires_grad=False), ) + if w_s.ndim > 1: + replace_parameter( + layer, + w_s_name, + torch.nn.Parameter(w_s.t().contiguous(), requires_grad=False), + ) + def apply_scaled_mm( self, *, @@ -297,7 +300,7 @@ def apply_scaled_mm( bias: torch.Tensor | None, output_shape: list, ) -> torch.Tensor: - output_shape[-1] = B.shape[0] + output_shape[-1] = B.shape[1] return rocm_aiter_ops.hipb_mm_fp8(A, B, As, Bs, bias, out_dtype).view( *output_shape ) diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 8ece4ebdc76e..5ac8c85ea30a 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -297,14 +297,6 @@ def on_gfx950() -> bool: # Turn on hipBLASLt online tuning if use AITER hipBLASLt GEMM. if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_LINEAR and envs.VLLM_ROCM_USE_AITER_LINEAR_HIPBMM and on_mi3xx(): os.environ["HIP_ONLINE_TUNING"] = "1" - # hipBMM requires aiter enabled, and disabling non-hipBMM linear - # avoids the +quant_fp8 custom op that triggers fuse_norm_quant, - # which fails with Float tensors. Rmsnorm requires composable_kernel - # submodule which may not be initialized, so disable it too. - - #os.environ["VLLM_ROCM_USE_AITER"] = "1" - #os.environ["VLLM_ROCM_USE_AITER_LINEAR"] = "0" - #os.environ["VLLM_ROCM_USE_AITER_RMSNORM"] = "0" @cache def use_rocm_custom_paged_attention( From d0013851e5410835a7e3267fbe48fae838fdde5a Mon Sep 17 00:00:00 2001 From: hanlin12 Date: Wed, 13 May 2026 09:51:42 +0000 Subject: [PATCH 15/19] fix some variable name Signed-off-by: hanlin12 --- vllm/_aiter_ops.py | 7 +++---- vllm/model_executor/kernels/linear/scaled_mm/aiter.py | 4 ++-- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/vllm/_aiter_ops.py b/vllm/_aiter_ops.py index ac4cbbe5f488..0cb229350c22 100644 --- a/vllm/_aiter_ops.py +++ b/vllm/_aiter_ops.py @@ -1257,7 +1257,6 @@ class rocm_aiter_ops: _FP8BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_FP8BMM _FP4BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_FP4BMM # TODO: Consolidate under _LINEAR_ENABLED - _HIP_FP8BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_LINEAR_HIPBMM _FP4_GEMM_DYNAMIC_QUANT_ASM = envs.VLLM_ROCM_USE_AITER_FP4_ASM_GEMM # TODO: Consolidate under VLLM_ROCM_USE_AITER_ROPE _TRITON_ROTARY_EMBED = envs.VLLM_ROCM_USE_AITER_TRITON_ROPE @@ -1289,7 +1288,7 @@ def refresh_env_variables(cls): cls._TRITON_UNIFIED_ATTN_ENABLED = envs.VLLM_ROCM_USE_AITER_UNIFIED_ATTENTION cls._FP8BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_FP8BMM cls._FP4BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_FP4BMM - cls._HIP_FP8BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_LINEAR_HIPBMM + cls._LINEAR_HIPBMM_ENABLED = envs.VLLM_ROCM_USE_AITER_LINEAR_HIPBMM cls._FP4_GEMM_DYNAMIC_QUANT_ASM = envs.VLLM_ROCM_USE_AITER_FP4_ASM_GEMM cls._TRITON_ROTARY_EMBED = envs.VLLM_ROCM_USE_AITER_TRITON_ROPE cls._MOE_SHARED_EXPERTS_ENABLED = envs.VLLM_ROCM_USE_AITER_FUSION_SHARED_EXPERTS @@ -1464,10 +1463,10 @@ def is_fp4bmm_enabled(cls) -> bool: @classmethod @if_aiter_supported - def is_hip_fp8bmm_enabled(cls) -> bool: + def is_linear_hipbmm_enabled(cls) -> bool: from vllm.platforms.rocm import on_mi3xx - return cls.is_linear_enabled() and on_mi3xx() + return cls.is_linear_enabled() and on_mi3xx() and cls._LINEAR_HIPBMM_ENABLED @classmethod @if_aiter_supported diff --git a/vllm/model_executor/kernels/linear/scaled_mm/aiter.py b/vllm/model_executor/kernels/linear/scaled_mm/aiter.py index a78cf96ce1f3..66cfa748ebd0 100644 --- a/vllm/model_executor/kernels/linear/scaled_mm/aiter.py +++ b/vllm/model_executor/kernels/linear/scaled_mm/aiter.py @@ -221,11 +221,11 @@ def is_supported( if not current_platform.is_rocm(): return False, "requires ROCm." - if not rocm_aiter_ops.is_hip_fp8bmm_enabled(): + if not rocm_aiter_ops.is_linear_hipbmm_enabled(): return ( False, "requires setting `VLLM_ROCM_USE_AITER=1` " - "and `VLLM_ROCM_USE_AITER_LINEAR=1`. ", + "and `VLLM_ROCM_USE_AITER_LINEAR=1` " "and `VLLM_ROCM_USE_AITER_LINEAR_HIPBMM=1`. ", ) try: From bdb9d332e26937c8e99265781a12bbfb4d36066e Mon Sep 17 00:00:00 2001 From: hanlin12 Date: Fri, 15 May 2026 02:53:34 +0000 Subject: [PATCH 16/19] fix missing line in aiter_ops Signed-off-by: hanlin12 --- vllm/_aiter_ops.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/_aiter_ops.py b/vllm/_aiter_ops.py index 0cb229350c22..e69fe6a10047 100644 --- a/vllm/_aiter_ops.py +++ b/vllm/_aiter_ops.py @@ -1256,6 +1256,7 @@ class rocm_aiter_ops: # TODO: Consolidate under _LINEAR_ENABLED _FP8BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_FP8BMM _FP4BMM_ENABLED = envs.VLLM_ROCM_USE_AITER_FP4BMM + _LINEAR_HIPBMM_ENABLED = envs.VLLM_ROCM_USE_AITER_LINEAR_HIPBMM # TODO: Consolidate under _LINEAR_ENABLED _FP4_GEMM_DYNAMIC_QUANT_ASM = envs.VLLM_ROCM_USE_AITER_FP4_ASM_GEMM # TODO: Consolidate under VLLM_ROCM_USE_AITER_ROPE From 6b52d9794274563c98ba0793f832304a878bbd88 Mon Sep 17 00:00:00 2001 From: hanlin12 Date: Thu, 28 May 2026 09:11:53 +0000 Subject: [PATCH 17/19] fix pre-commit Signed-off-by: hanlin12 --- vllm/envs.py | 3 +-- .../model_executor/kernels/linear/scaled_mm/aiter.py | 5 ++--- vllm/platforms/rocm.py | 12 +++++++++--- 3 files changed, 12 insertions(+), 8 deletions(-) diff --git a/vllm/envs.py b/vllm/envs.py index 33906fa43b78..17a7d2215379 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -1101,8 +1101,7 @@ def _resolve_rust_frontend_path() -> str | None: os.getenv("VLLM_ROCM_USE_AITER_LINEAR", "True").lower() in ("true", "1") ), "VLLM_ROCM_USE_AITER_LINEAR_HIPBMM": lambda: ( - os.getenv("VLLM_ROCM_USE_AITER_LINEAR_HIPBMM", "False").lower() - in ("true", "1") + os.getenv("VLLM_ROCM_USE_AITER_LINEAR_HIPBMM", "False").lower() in ("true", "1") ), # Whether to use aiter moe ops. # By default is enabled. diff --git a/vllm/model_executor/kernels/linear/scaled_mm/aiter.py b/vllm/model_executor/kernels/linear/scaled_mm/aiter.py index 66cfa748ebd0..7ce4705510a2 100644 --- a/vllm/model_executor/kernels/linear/scaled_mm/aiter.py +++ b/vllm/model_executor/kernels/linear/scaled_mm/aiter.py @@ -4,7 +4,6 @@ import torch -import vllm.envs as envs from vllm import _custom_ops as ops from vllm._aiter_ops import ( rocm_aiter_ops, @@ -247,7 +246,6 @@ def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | Non if c.weight_shape is None: return False, "weight_shape is required for Aiter kernels" N, K = c.weight_shape - fp8_dtype = current_platform.fp8_dtype() if c.out_dtype is not torch.bfloat16: return False, "requires bfloat16 output dtype." @@ -261,7 +259,8 @@ def can_implement(cls, c: FP8ScaledMMLinearLayerConfig) -> tuple[bool, str | Non if not (N >= 16 and N % 16 == 0 and K % 16 == 0): return ( False, - f"requires N >= 16 and both N and K divisible by 16, received N={N} and K={K}.", + "requires N >= 16 and both N and K divisible by 16, " + f"received N={N} and K={K}.", ) return True, None diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 221e75f11cbc..11c4d3bd891b 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -135,7 +135,6 @@ def _sync_hip_cuda_env_vars(): _sync_hip_cuda_env_vars() - # AMDSMI utils # Note that NVML is not affected by `{CUDA/HIP}_VISIBLE_DEVICES`, # all the related functions work on real physical device ids. @@ -298,11 +297,18 @@ def on_gfx942() -> bool: def on_gfx950() -> bool: return _ON_GFX950 + # Enable HIP online tuning early, before hipBLASLt initializes. # Turn on hipBLASLt online tuning if use AITER hipBLASLt GEMM. -if envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER_LINEAR and envs.VLLM_ROCM_USE_AITER_LINEAR_HIPBMM and on_mi3xx(): +if ( + envs.VLLM_ROCM_USE_AITER + and envs.VLLM_ROCM_USE_AITER_LINEAR + and envs.VLLM_ROCM_USE_AITER_LINEAR_HIPBMM + and on_mi3xx() +): os.environ["HIP_ONLINE_TUNING"] = "1" - + + @cache def use_rocm_custom_paged_attention( qtype: torch.dtype, From 6c0d85e74d9d735f0ddb68d9c8ba274408d60d93 Mon Sep 17 00:00:00 2001 From: hanlin12 Date: Fri, 29 May 2026 02:49:07 +0000 Subject: [PATCH 18/19] Add accuracy unit-test of Aiter hipBlaslt Signed-off-by: hanlin12 --- .../aiter/test_aiter_hipb_mm_linear_kernel.py | 407 ++++++++++++++++++ .../kernels/linear/scaled_mm/aiter.py | 6 +- 2 files changed, 410 insertions(+), 3 deletions(-) create mode 100644 tests/rocm/aiter/test_aiter_hipb_mm_linear_kernel.py diff --git a/tests/rocm/aiter/test_aiter_hipb_mm_linear_kernel.py b/tests/rocm/aiter/test_aiter_hipb_mm_linear_kernel.py new file mode 100644 index 000000000000..cb19f993389c --- /dev/null +++ b/tests/rocm/aiter/test_aiter_hipb_mm_linear_kernel.py @@ -0,0 +1,407 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import csv +import importlib.util +import importlib +import os + +import pytest +import torch + +from tests.utils import TestFP8Layer +from vllm._aiter_ops import rocm_aiter_ops +from vllm.model_executor.kernels.linear.scaled_mm.ScaledMMLinearKernel import ( + FP8ScaledMMLinearLayerConfig, +) +from vllm.model_executor.kernels.linear.scaled_mm.aiter import ( + AiterHipbMMPerTokenFp8ScaledMMLinearKernel, +) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + kFp8DynamicTokenSym, + kFp8StaticChannelSym, + kFp8StaticTensorSym, +) +from vllm.platforms import current_platform + +aiter_available = importlib.util.find_spec("aiter") is not None + +pytestmark = [ + pytest.mark.skipif( + not ( + current_platform.is_rocm() + and current_platform.supports_fp8() + and aiter_available + ), + reason="Requires ROCm + FP8 support + aiter", + ), + pytest.mark.usefixtures("default_vllm_config"), +] + + +@pytest.fixture +def enable_hipb_mm_kernel(monkeypatch: pytest.MonkeyPatch): + monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") + monkeypatch.setenv("VLLM_ROCM_USE_AITER_LINEAR", "1") + monkeypatch.setenv("VLLM_ROCM_USE_AITER_LINEAR_HIPBMM", "1") + rocm_aiter_ops.refresh_env_variables() + yield + rocm_aiter_ops.refresh_env_variables() + + +def _make_config( + *, + weight_quant_key=kFp8StaticChannelSym, + out_dtype: torch.dtype = torch.bfloat16, + weight_shape: tuple[int, int] = (512, 4096), +) -> FP8ScaledMMLinearLayerConfig: + return FP8ScaledMMLinearLayerConfig( + weight_quant_key=weight_quant_key, + activation_quant_key=kFp8DynamicTokenSym, + weight_shape=weight_shape, + input_dtype=torch.bfloat16, + out_dtype=out_dtype, + ) + + +def _find_csv_row(path: str, m: int, n: int, k: int) -> dict | None: + if not os.path.exists(path): + return None + + with open(path, newline="") as f: + reader = csv.DictReader(f, skipinitialspace=True) + for row in reader: + try: + if ( + int(row.get("m", -1)) == m + and int(row.get("n", -1)) == n + and int(row.get("k", -1)) == k + ): + return dict(row) + except (TypeError, ValueError): + continue + return None + + +def _skip_if_no_hipb_mm_solution(exc: RuntimeError) -> None: + if "hipblasLtMatmulAlgoGetHeuristic found 0 valid solutions" in str(exc): + pytest.skip( + "hipb_mm bpreshuffle path has no valid hipBLASLt solution on " + "this ROCm stack." + ) + + +def _check_bpreshuffle_runtime_support(weight_shape: tuple[int, int], num_tokens: int): + import aiter + from aiter.ops.shuffle import shuffle_weight + + x = torch.randn( + num_tokens, weight_shape[1], dtype=torch.bfloat16, device="cuda" + ) + w = torch.randn(weight_shape, dtype=torch.bfloat16, device="cuda") + + aiter.hipb_create_extension() + x_q, x_scale = aiter.pertoken_quant(x, quant_dtype=current_platform.fp8_dtype()) + w_q, w_scale = aiter.pertoken_quant(w, quant_dtype=current_platform.fp8_dtype()) + + try: + aiter.hipb_mm( + x_q, + shuffle_weight(w_q, layout=(16, 16)).t(), + solution_index=-1, + out_dtype=torch.bfloat16, + scaleA=x_scale, + scaleB=w_scale.t().contiguous(), + scaleOut=None, + bpreshuffle=True, + ) + except RuntimeError as exc: + _skip_if_no_hipb_mm_solution(exc) + raise + + +def test_hipb_mm_kernel_requires_hipbmm_flag(monkeypatch: pytest.MonkeyPatch): + # The kernel rejects when `is_hip_fp8bmm_enabled()` is False. That helper + # requires AITER + AITER_LINEAR + MI3xx, so dropping AITER_LINEAR exercises + # the rejection branch. + monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") + monkeypatch.setenv("VLLM_ROCM_USE_AITER_LINEAR", "0") + monkeypatch.delenv("VLLM_ROCM_USE_AITER_LINEAR_HIPBMM", raising=False) + rocm_aiter_ops.refresh_env_variables() + + is_supported, reason = AiterHipbMMPerTokenFp8ScaledMMLinearKernel.is_supported() + + assert not is_supported + assert reason == ( + "requires setting `VLLM_ROCM_USE_AITER=1`, " + "`VLLM_ROCM_USE_AITER_LINEAR=1`, " + "and `VLLM_ROCM_USE_AITER_LINEAR_HIPBMM=1`." + ) + + +def test_hipb_mm_flag_enables_hip_online_tuning( + monkeypatch: pytest.MonkeyPatch, +): + import vllm.envs as envs_mod + import vllm.platforms.rocm as rocm_mod + + # The rocm.py gate requires all three AITER flags (and MI3xx) to auto-set + # HIP_ONLINE_TUNING. + monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1") + monkeypatch.setenv("VLLM_ROCM_USE_AITER_LINEAR", "1") + monkeypatch.setenv("VLLM_ROCM_USE_AITER_LINEAR_HIPBMM", "1") + + try: + importlib.reload(envs_mod) + importlib.reload(rocm_mod) + assert envs_mod.VLLM_ROCM_USE_AITER + assert envs_mod.VLLM_ROCM_USE_AITER_LINEAR + assert envs_mod.VLLM_ROCM_USE_AITER_LINEAR_HIPBMM + assert os.environ.get("HIP_ONLINE_TUNING") == "1" + finally: + monkeypatch.undo() + os.environ.pop("HIP_ONLINE_TUNING", None) + importlib.reload(envs_mod) + importlib.reload(rocm_mod) + rocm_aiter_ops.refresh_env_variables() + + +def test_hipb_mm_kernel_can_implement_success(enable_hipb_mm_kernel): + can_implement, reason = AiterHipbMMPerTokenFp8ScaledMMLinearKernel.can_implement( + _make_config() + ) + + assert can_implement + assert reason is None + + +@pytest.mark.parametrize( + ("config", "expected_reason"), + [ + ( + _make_config(weight_quant_key=kFp8StaticTensorSym), + "requires per token activation scales and per channel weight scales.", + ), + ( + _make_config(out_dtype=torch.float16), + "requires bfloat16 output dtype.", + ), + ( + _make_config(weight_shape=(8, 4090)), + "requires N >= 16 and both N and K divisible by 16, received N=8 and K=4090.", + ), + ], +) +def test_hipb_mm_kernel_can_implement_rejects_unsupported_configs( + enable_hipb_mm_kernel, + config: FP8ScaledMMLinearLayerConfig, + expected_reason: str, +): + can_implement, reason = AiterHipbMMPerTokenFp8ScaledMMLinearKernel.can_implement( + config + ) + + assert not can_implement + assert reason == expected_reason + + +def test_hipb_mm_kernel_process_weights_after_loading_shuffles_weights( + enable_hipb_mm_kernel, +): + weight_shape = (512, 4096) + kernel = AiterHipbMMPerTokenFp8ScaledMMLinearKernel( + _make_config(weight_shape=weight_shape), + layer_param_names=("weight", "weight_scale", "input_scale", "input_scale_ub"), + ) + + layer = torch.nn.Module() + layer.weight = torch.nn.Parameter( + torch.rand(weight_shape, device="cuda").to(current_platform.fp8_dtype()).t(), + requires_grad=False, + ) + layer.weight_scale = torch.nn.Parameter( + torch.rand((weight_shape[0], 1), dtype=torch.float32, device="cuda"), + requires_grad=False, + ) + layer.input_scale = None + layer.input_scale_ub = None + + original_weight = layer.weight.detach().clone() + original_weight_scale = layer.weight_scale.detach().clone() + + kernel.process_weights_after_loading(layer) + + # process_weights_after_loading now pre-applies the transposes that used + # to live in _rocm_aiter_hipb_mm_fp8_impl, so the stored weight is the + # shuffled tensor with a trailing `.t()` view, and the stored weight scale + # is its transposed-contiguous form. + expected_weight = rocm_aiter_ops.shuffle_weight( + original_weight.t().contiguous() + ).t() + torch.testing.assert_close(layer.weight, expected_weight) + + expected_weight_scale = original_weight_scale.t().contiguous() + torch.testing.assert_close(layer.weight_scale, expected_weight_scale) + + +def test_hipb_mm_kernel_forward_matches_raw_aiter_hipb_mm(enable_hipb_mm_kernel): + import aiter + + weight_shape = (512, 4096) + _check_bpreshuffle_runtime_support(weight_shape, num_tokens=32) + + layer = TestFP8Layer( + weight_shape=weight_shape, + activation_quant_key=kFp8DynamicTokenSym, + weight_quant_key=kFp8StaticChannelSym, + input_dtype=torch.bfloat16, + out_dtype=torch.bfloat16, + device=torch.device("cuda"), + force_kernel=AiterHipbMMPerTokenFp8ScaledMMLinearKernel, + ) + + # hipb_mm uses a transposed-result GEMM internally, so the flattened token + # count becomes the effective N dimension passed into hipBLASLt. Keep it + # aligned to avoid heuristic failures for tiny N. + x = torch.randn(2, 16, weight_shape[1], dtype=torch.bfloat16, device="cuda") + bias = torch.randn(weight_shape[0], dtype=torch.bfloat16, device="cuda") + + try: + out = layer(x, bias) + except RuntimeError as exc: + _skip_if_no_hipb_mm_solution(exc) + raise + + x_2d = x.view(-1, x.shape[-1]) + x_q, x_scale = layer.kernel.quant_fp8( + x_2d, + layer.input_scale, + layer.input_scale_ub, + ) + try: + # process_weights_after_loading already applies the trailing `.t()` on + # the shuffled weight and the `.t().contiguous()` on the weight scale, + # so the raw aiter call uses them directly. + expected = aiter.hipb_mm( + x_q, + layer.weight, + solution_index=-1, + bias=bias, + out_dtype=torch.bfloat16, + scaleA=x_scale, + scaleB=layer.weight_scale, + scaleOut=None, + bpreshuffle=True, + ).view(*out.shape) + except RuntimeError as exc: + _skip_if_no_hipb_mm_solution(exc) + raise + + assert isinstance(layer.kernel, AiterHipbMMPerTokenFp8ScaledMMLinearKernel) + assert out.shape == (2, 16, weight_shape[0]) + torch.testing.assert_close(out, expected) + + +def test_hipb_mm_kernel_forward_accuracy(enable_hipb_mm_kernel): + """Kernel output should match a dequantized fp32 reference within + fp8 per-token / per-channel quantization noise.""" + weight_shape = (512, 4096) # (N, K) + num_tokens = 32 + _check_bpreshuffle_runtime_support(weight_shape, num_tokens=num_tokens) + + fp8_dtype = current_platform.fp8_dtype() + fp8_max = torch.finfo(fp8_dtype).max + device = torch.device("cuda") + + # Build a bf16 weight and quantize per output channel (one scale per row). + w_bf16 = torch.randn(weight_shape, dtype=torch.bfloat16, device=device) + w_amax = w_bf16.abs().amax(dim=1, keepdim=True).to(torch.float32) + w_scale = (w_amax / fp8_max).clamp(min=1e-12) + w_fp8 = (w_bf16.to(torch.float32) / w_scale).clamp(-fp8_max, fp8_max).to(fp8_dtype) + w_dequant = w_fp8.to(torch.float32) * w_scale + + bias = torch.randn(weight_shape[0], dtype=torch.bfloat16, device=device) + + layer = torch.nn.Module() + # Pre-`process_weights_after_loading` convention: weight stored as the + # `[K, N]` view of the fp8 tensor. + layer.weight = torch.nn.Parameter(w_fp8.t(), requires_grad=False) + layer.weight_scale = torch.nn.Parameter(w_scale, requires_grad=False) + layer.input_scale = None + layer.input_scale_ub = None + + kernel = AiterHipbMMPerTokenFp8ScaledMMLinearKernel( + _make_config(weight_shape=weight_shape), + layer_param_names=("weight", "weight_scale", "input_scale", "input_scale_ub"), + ) + kernel.process_weights_after_loading(layer) + + x = torch.randn(num_tokens, weight_shape[1], dtype=torch.bfloat16, device=device) + + try: + out = kernel.apply_weights(layer, x, bias) + except RuntimeError as exc: + _skip_if_no_hipb_mm_solution(exc) + raise + + # Reference: quantize x per-token the same way the kernel does, then run + # the matmul in fp32 against the dequantized weight. This isolates plumbing + # / reduction bugs from inherent fp8 quantization noise. + x_amax = x.abs().amax(dim=1, keepdim=True).to(torch.float32) + x_scale_ref = (x_amax / fp8_max).clamp(min=1e-12) + x_q = (x.to(torch.float32) / x_scale_ref).clamp(-fp8_max, fp8_max).to(fp8_dtype) + x_dequant = x_q.to(torch.float32) * x_scale_ref + expected = (x_dequant @ w_dequant.t() + bias.to(torch.float32)).to(torch.bfloat16) + + assert out.shape == (num_tokens, weight_shape[0]) + # K=4096 fp8 reduction leaves room for accumulation order drift and + # catastrophic cancellation on near-zero outputs; tolerances are loose + # enough to absorb that but tight enough to catch wrong layouts, missing + # bias, swapped scales, etc. + torch.testing.assert_close(out, expected, atol=5.0, rtol=0.1) + + +def test_hipb_mm_kernel_online_tuning_writes_csv( + enable_hipb_mm_kernel, + monkeypatch: pytest.MonkeyPatch, + tmp_path, +): + weight_shape = (256, 4096) + cache_file = tmp_path / "hip_online_tuning_res.csv" + + _check_bpreshuffle_runtime_support(weight_shape, num_tokens=16) + + monkeypatch.setenv("HIP_ONLINE_TUNING", "1") + monkeypatch.chdir(tmp_path) + + layer = TestFP8Layer( + weight_shape=weight_shape, + activation_quant_key=kFp8DynamicTokenSym, + weight_quant_key=kFp8StaticChannelSym, + input_dtype=torch.bfloat16, + out_dtype=torch.bfloat16, + device=torch.device("cuda"), + force_kernel=AiterHipbMMPerTokenFp8ScaledMMLinearKernel, + ) + + # The effective heuristic N dimension is the flattened token count. + x = torch.randn(16, weight_shape[1], dtype=torch.bfloat16, device="cuda") + try: + out = layer(x) + except RuntimeError as exc: + _skip_if_no_hipb_mm_solution(exc) + raise + torch.accelerator.synchronize() + + assert out.shape == (16, weight_shape[0]) + assert cache_file.exists() + + # hipb_mm records the internal GEMM dimensions used by hipBLASLt after its + # transposed-result transformation. + row = _find_csv_row( + str(cache_file), + m=weight_shape[0], + n=x.shape[0], + k=weight_shape[1], + ) + assert row is not None diff --git a/vllm/model_executor/kernels/linear/scaled_mm/aiter.py b/vllm/model_executor/kernels/linear/scaled_mm/aiter.py index 7ce4705510a2..1b39491ab346 100644 --- a/vllm/model_executor/kernels/linear/scaled_mm/aiter.py +++ b/vllm/model_executor/kernels/linear/scaled_mm/aiter.py @@ -223,9 +223,9 @@ def is_supported( if not rocm_aiter_ops.is_linear_hipbmm_enabled(): return ( False, - "requires setting `VLLM_ROCM_USE_AITER=1` " - "and `VLLM_ROCM_USE_AITER_LINEAR=1` " - "and `VLLM_ROCM_USE_AITER_LINEAR_HIPBMM=1`. ", + "requires setting `VLLM_ROCM_USE_AITER=1`, " + "`VLLM_ROCM_USE_AITER_LINEAR=1`, " + "and `VLLM_ROCM_USE_AITER_LINEAR_HIPBMM=1`.", ) try: import aiter # noqa: F401 From f6aed71f18f73246095e1f63303c3a96e78ea73f Mon Sep 17 00:00:00 2001 From: hanlin12 Date: Fri, 29 May 2026 03:05:05 +0000 Subject: [PATCH 19/19] fix pre-commit Signed-off-by: hanlin12 --- .../aiter/test_aiter_hipb_mm_linear_kernel.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/tests/rocm/aiter/test_aiter_hipb_mm_linear_kernel.py b/tests/rocm/aiter/test_aiter_hipb_mm_linear_kernel.py index cb19f993389c..92017e95cb7a 100644 --- a/tests/rocm/aiter/test_aiter_hipb_mm_linear_kernel.py +++ b/tests/rocm/aiter/test_aiter_hipb_mm_linear_kernel.py @@ -2,8 +2,8 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import csv -import importlib.util import importlib +import importlib.util import os import pytest @@ -11,12 +11,12 @@ from tests.utils import TestFP8Layer from vllm._aiter_ops import rocm_aiter_ops -from vllm.model_executor.kernels.linear.scaled_mm.ScaledMMLinearKernel import ( - FP8ScaledMMLinearLayerConfig, -) from vllm.model_executor.kernels.linear.scaled_mm.aiter import ( AiterHipbMMPerTokenFp8ScaledMMLinearKernel, ) +from vllm.model_executor.kernels.linear.scaled_mm.ScaledMMLinearKernel import ( + FP8ScaledMMLinearLayerConfig, +) from vllm.model_executor.layers.quantization.utils.quant_utils import ( kFp8DynamicTokenSym, kFp8StaticChannelSym, @@ -95,9 +95,7 @@ def _check_bpreshuffle_runtime_support(weight_shape: tuple[int, int], num_tokens import aiter from aiter.ops.shuffle import shuffle_weight - x = torch.randn( - num_tokens, weight_shape[1], dtype=torch.bfloat16, device="cuda" - ) + x = torch.randn(num_tokens, weight_shape[1], dtype=torch.bfloat16, device="cuda") w = torch.randn(weight_shape, dtype=torch.bfloat16, device="cuda") aiter.hipb_create_extension() @@ -188,7 +186,8 @@ def test_hipb_mm_kernel_can_implement_success(enable_hipb_mm_kernel): ), ( _make_config(weight_shape=(8, 4090)), - "requires N >= 16 and both N and K divisible by 16, received N=8 and K=4090.", + "requires N >= 16 and both N and K divisible by 16, " + "received N=8 and K=4090.", ), ], )