Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
91 commits
Select commit Hold shift + click to select a range
8a542b7
create initial block scaled mm kernels and a common base
maralbahari Feb 5, 2026
0ebcf78
remove W8A8Fp8BlockLinearOp and adopt mm kernel selection
maralbahari Feb 5, 2026
b76074c
remove W8A8Fp8BlockLinearOp from unit tests
maralbahari Feb 5, 2026
3c7049e
Update vllm/model_executor/layers/quantization/kernels/base.py
maralbahari Feb 5, 2026
08a893d
Update vllm/model_executor/layers/quantization/kernels/base.py
maralbahari Feb 5, 2026
9847109
Update vllm/model_executor/layers/quantization/kernels/scaled_mm/aite…
maralbahari Feb 5, 2026
5d58935
Update vllm/model_executor/layers/quantization/kernels/scaled_mm/cuda.py
maralbahari Feb 5, 2026
9887678
Update vllm/model_executor/layers/quantization/kernels/scaled_mm/Bloc…
maralbahari Feb 5, 2026
4b53675
fix pre-commit issues and typings
maralbahari Feb 5, 2026
acac7c1
improve typing
maralbahari Feb 5, 2026
61bfb5b
Merge remote-tracking branch 'origin/2n-block-scaled-rfc-pr' into 3n-…
maralbahari Feb 5, 2026
3363c88
add missing kwargs for aiter fp8 block scaled mm func and return stat…
maralbahari Feb 9, 2026
79951e2
Merge remote-tracking branch 'origin/2n-block-scaled-rfc-pr' into 3n-…
maralbahari Feb 9, 2026
6465faa
fix f-string
maralbahari Feb 9, 2026
5b3c2e1
Merge remote-tracking branch 'origin/2n-block-scaled-rfc-pr' into 3n-…
maralbahari Feb 9, 2026
8dd23bd
Merge remote-tracking branch 'origin/main' into 3n-block-scaled-rfc-pr
maralbahari Feb 9, 2026
320ced0
improve documentation and fix typings in init_fp8_linear_kernel
maralbahari Feb 9, 2026
d0cd8a2
Merge remote-tracking branch 'origin/2n-block-scaled-rfc-pr' into 3n-…
maralbahari Feb 9, 2026
f555f75
Merge remote-tracking branch 'origin/main' into 2n-block-scaled-rfc-pr
maralbahari Feb 23, 2026
614cef5
Merge remote-tracking branch 'origin/2n-block-scaled-rfc-pr' into 3n-…
maralbahari Feb 23, 2026
c43b6cd
Merge remote-tracking branch 'origin/main' into 2n-block-scaled-rfc-pr
maralbahari Feb 24, 2026
ce88d6e
Merge remote-tracking branch 'origin/2n-block-scaled-rfc-pr' into 3n-…
maralbahari Feb 25, 2026
08d6a54
fix import error
maralbahari Feb 25, 2026
4bc9347
fix imports
maralbahari Feb 25, 2026
d001db8
fix import
maralbahari Feb 25, 2026
de82fd1
use the same variable name for input quantization to follow scaled_mm
maralbahari Feb 25, 2026
7a26e60
address PR comments
maralbahari Feb 25, 2026
15c3d44
Merge remote-tracking branch 'origin/2n-block-scaled-rfc-pr' into 3n-…
maralbahari Feb 25, 2026
e5bbb6c
bugfixes
maralbahari Feb 26, 2026
cb46979
address PR comment
maralbahari Feb 26, 2026
f27d31a
bugfix compressed tensors
maralbahari Feb 26, 2026
d805795
fix unit tests
maralbahari Feb 26, 2026
bfcd522
add group_size check for cutlass and deep_gemm kernels and update fus…
maralbahari Feb 27, 2026
ca8b19d
fix wrong check on block fp8 cutlass can_implement
maralbahari Feb 27, 2026
00a7522
fix potential bugs in deepgemm
maralbahari Feb 27, 2026
e97a479
bugfix reading correct weight_scale in block scaled mm linear
maralbahari Mar 2, 2026
0236228
Merge remote-tracking branch 'origin/main' into 3n-block-scaled-rfc-pr
maralbahari Mar 6, 2026
6ce17a9
fix pre-commit issue
maralbahari Mar 6, 2026
2645041
Merge remote-tracking branch 'origin/main' into 3n-block-scaled-rfc-pr
maralbahari Mar 9, 2026
a3d7831
Merge branch 'main' into 3n-block-scaled-rfc-pr
tjtanaa Mar 11, 2026
ed5a54c
Merge branch 'main' into 3n-block-scaled-rfc-pr
tjtanaa Mar 19, 2026
c28dac9
initialize kernels in create_weights
maralbahari Mar 23, 2026
2478541
Merge remote-tracking branch 'origin/main' into 3n-block-scaled-rfc-pr
maralbahari Mar 23, 2026
dba5697
fix fusion unit tests
maralbahari Mar 23, 2026
f0ca1e9
fix fusion unit test and online fp8 quant
maralbahari Mar 23, 2026
f595112
fix pre-commit error
maralbahari Mar 23, 2026
55096ef
fix input_dtype
maralbahari Mar 23, 2026
64df301
fix unittest
maralbahari Mar 23, 2026
a08f623
Merge remote-tracking branch 'origin/main' into 3n-block-scaled-rfc-pr
maralbahari Mar 23, 2026
1d5c1b7
fix unit tests
maralbahari Mar 23, 2026
98f215b
Merge remote-tracking branch 'origin/main' into 3n-block-scaled-rfc-pr
maralbahari Mar 24, 2026
1b65c2e
fix unit test for test_modelopt
maralbahari Mar 24, 2026
f093d82
remove unused function.
maralbahari Mar 24, 2026
cbb0599
fix Quantization unit test
maralbahari Mar 24, 2026
05b7cc9
attempt to fix marlin fp8 quant fp8
maralbahari Mar 24, 2026
929d05d
Merge branch 'main' into 3n-block-scaled-rfc-pr
tjtanaa Mar 25, 2026
8593412
Merge branch 'main' into 3n-block-scaled-rfc-pr
tjtanaa Mar 25, 2026
5c73e37
Merge remote-tracking branch 'origin/main' into 3n-block-scaled-rfc-pr
maralbahari Mar 26, 2026
a07f484
Merge branch '3n-block-scaled-rfc-pr' of https://github.com/EmbeddedL…
maralbahari Mar 26, 2026
f3a1cd2
Merge remote-tracking branch 'origin/main' into 3n-block-scaled-rfc-pr
maralbahari Mar 26, 2026
c86b172
fix deepgemm ep2 accuracy issue
maralbahari Mar 28, 2026
cf0618c
Merge remote-tracking branch 'origin/main' into 3n-block-scaled-rfc-pr
maralbahari Mar 28, 2026
7930f5a
Merge branch 'main' into 3n-block-scaled-rfc-pr
tjtanaa Mar 28, 2026
f01ba9e
Merge branch 'main' into 3n-block-scaled-rfc-pr
maralbahari Mar 30, 2026
ad95ceb
avoid calling is_flashinfer_fp8_blockscale_gemm_supported as class va…
maralbahari Mar 30, 2026
f5348e8
Merge branch '3n-block-scaled-rfc-pr' of https://github.com/EmbeddedL…
maralbahari Mar 30, 2026
f012fae
fix torch compile issue with torch.cond
maralbahari Mar 30, 2026
9985919
fix torch.cond torch.compile errors
maralbahari Mar 31, 2026
9520a25
Merge branch 'main' into 3n-block-scaled-rfc-pr
vllmellm Mar 31, 2026
ab00fb7
bugfix wrong input quantization
maralbahari Mar 31, 2026
594fc5b
Merge remote-tracking branch 'origin/main' into 3n-block-scaled-rfc-pr
maralbahari Mar 31, 2026
ab2d3fc
fix torch.cond fx-graph break
maralbahari Apr 2, 2026
a3041f7
Merge remote-tracking branch 'origin/main' into 3n-block-scaled-rfc-pr
maralbahari Apr 2, 2026
42f3334
fix mxfp8 test fail
maralbahari Apr 3, 2026
697d747
clean code
maralbahari Apr 3, 2026
0498d02
Merge remote-tracking branch 'origin/main' into 3n-block-scaled-rfc-pr
maralbahari Apr 3, 2026
af9ec82
fix new attention mla fusion unit test
maralbahari Apr 3, 2026
f79b1fd
fix wrong skip condition
maralbahari Apr 3, 2026
5e76e75
fix mxfp8 unit test
maralbahari Apr 3, 2026
24a4f25
maybe fix cutlass block scaled gemm
maralbahari Apr 3, 2026
7e254a2
clean code
maralbahari Apr 3, 2026
46ffd25
Merge remote-tracking branch 'origin/main' into 3n-block-scaled-rfc-pr
maralbahari Apr 3, 2026
186fe8f
fix batch-invariant issue
maralbahari Apr 3, 2026
8056522
Merge remote-tracking branch 'origin/main' into 3n-block-scaled-rfc-pr
maralbahari Apr 6, 2026
9cb0ebf
fix online fp8
maralbahari Apr 6, 2026
bb9920f
fix pre-commit
maralbahari Apr 6, 2026
884b952
fix mxfp8 linearmethod
maralbahari Apr 6, 2026
8e6b3ab
fix pytorch compile test
maralbahari Apr 6, 2026
1a88320
Merge branch 'main' into 3n-block-scaled-rfc-pr
tjtanaa Apr 7, 2026
5b572ac
Merge branch 'main' into 3n-block-scaled-rfc-pr
tjtanaa Apr 7, 2026
92ba677
Merge branch 'main' into 3n-block-scaled-rfc-pr
tjtanaa Apr 8, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 12 additions & 7 deletions benchmarks/kernels/benchmark_block_fp8_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,12 @@
import torch

from vllm.benchmarks.lib.utils import default_vllm_config
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
W8A8BlockFp8LinearOp,
from vllm.model_executor.kernels.linear import (
init_fp8_linear_kernel,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
create_fp8_quant_key,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
CUTLASS_BLOCK_FP8_SUPPORTED,
Expand Down Expand Up @@ -70,11 +71,15 @@ def build_w8a8_block_fp8_runner(M, N, K, block_size, device, use_cutlass):
weight_group_shape = GroupShape(block_n, block_k)
act_quant_group_shape = GroupShape(1, block_k) # Per-token, per-group quantization

linear_op = W8A8BlockFp8LinearOp(
weight_group_shape=weight_group_shape,
act_quant_group_shape=act_quant_group_shape,
cutlass_block_fp8_supported=use_cutlass,
use_aiter_and_is_supported=False,
linear_op = init_fp8_linear_kernel(
weight_quant_key=create_fp8_quant_key(
static=True, group_shape=weight_group_shape
),
activation_quant_key=create_fp8_quant_key(
static=False, group_shape=act_quant_group_shape
),
out_dtype=torch.get_default_dtype(),
module_name="build_w8a8_block_fp8_runner",
)

def run():
Expand Down
15 changes: 11 additions & 4 deletions tests/compile/passes/distributed/test_fusion_all_reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,9 @@


class TestAllReduceRMSNormModel(torch.nn.Module):
def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
def __init__(
self, hidden_size=16, token_num=16, eps=1e-6, dtype: torch.dtype = torch.float16
):
super().__init__()
self.hidden_size = hidden_size
self.eps = eps
Expand Down Expand Up @@ -78,7 +80,9 @@ def ops_in_model_after(self):
class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
quant_key = kFp8StaticTensorSym

def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
def __init__(
self, hidden_size=16, token_num=16, eps=1e-6, dtype: torch.dtype = torch.float16
):
super().__init__()
self.hidden_size = hidden_size
self.eps = eps
Expand All @@ -88,6 +92,7 @@ def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
weight_shape=(hidden_size, hidden_size),
activation_quant_key=self.quant_key,
weight_quant_key=self.quant_key,
input_dtype=dtype,
)
for i in range(3)
]
Expand Down Expand Up @@ -127,7 +132,9 @@ def ops_in_model_before(self):


class TestAllReduceFusedAddRMSNormStaticQuantFP4Model(torch.nn.Module):
def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
def __init__(
self, hidden_size=16, token_num=16, eps=1e-6, dtype: torch.dtype = torch.float16
):
super().__init__()
self.hidden_size = hidden_size
self.eps = eps
Expand Down Expand Up @@ -314,7 +321,7 @@ def all_reduce_fusion_pass_on_test_model(
)

token_num = batch_size * seq_len
model = test_model_cls(hidden_size, token_num)
model = test_model_cls(hidden_size, token_num, dtype=dtype)

hidden_states = torch.randn((token_num, hidden_size), requires_grad=False)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ def __init__(self, hidden_size=16, eps=1e-6):
weight_shape=(hidden_size, hidden_size),
activation_quant_key=self.quant_key,
weight_quant_key=self.quant_key,
input_dtype=self.vllm_config.model_config.dtype,
)
for i in range(3)
]
Expand Down
3 changes: 3 additions & 0 deletions tests/compile/passes/test_functionalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
ModelConfig,
PassConfig,
VllmConfig,
get_current_vllm_config,
set_current_vllm_config,
)
from vllm.model_executor.layers.activation import SiluAndMul
Expand All @@ -49,6 +50,7 @@ def __init__(self, hidden_size: int = 128):
weight_shape=(hidden_size, hidden_size),
activation_quant_key=self.quant_key,
weight_quant_key=self.quant_key,
input_dtype=get_current_vllm_config().model_config.dtype,
)

def forward(self, x):
Expand Down Expand Up @@ -92,6 +94,7 @@ def __init__(self, hidden_size=16, intermediate_size=32):
weight_shape=(hidden_size, intermediate_size),
activation_quant_key=self.quant_key,
weight_quant_key=self.quant_key,
input_dtype=get_current_vllm_config().model_config.dtype,
)

def forward(self, hidden_states, residual):
Expand Down
100 changes: 49 additions & 51 deletions tests/compile/passes/test_fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import vllm.ir.ops
import vllm.plugins
from tests.compile.backend import TestBackend
from tests.utils import TestBlockFP8Layer, TestFP8Layer
from tests.utils import TestFP8Layer
from vllm._aiter_ops import IS_AITER_FOUND, rocm_aiter_ops
from vllm.compilation.passes.fusion.matcher_utils import QUANT_OPS
from vllm.compilation.passes.fusion.rms_quant_fusion import (
Expand All @@ -28,19 +28,23 @@
VllmConfig,
)
from vllm.model_executor.kernels.linear import (
AiterFp8BlockScaledMMKernel,
ChannelWiseTorchFP8ScaledMMLinearKernel,
CutlassFp8BlockScaledMMKernel,
CutlassFP8ScaledMMLinearKernel,
DeepGemmFp8BlockScaledMMKernel,
FlashInferFp8DeepGEMMDynamicBlockScaledKernel,
FlashInferFP8ScaledMMLinearKernel,
FP8ScaledMMLinearKernel,
PerTensorTorchFP8ScaledMMLinearKernel,
ROCmFP8ScaledMMLinearKernel,
RowWiseTorchFP8ScaledMMLinearKernel,
TritonFp8BlockScaledMMKernel,
_KernelT,
)
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
QuantKey,
ScaleDesc,
create_fp8_quant_key,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
cutlass_block_fp8_supported,
Expand All @@ -66,9 +70,12 @@
(PerTensorTorchFP8ScaledMMLinearKernel, GroupShape.PER_TENSOR),
# ChannelWiseTorchFP8ScaledMMLinearKernel only supports per-token
(ChannelWiseTorchFP8ScaledMMLinearKernel, GroupShape.PER_TOKEN),
# Blockwise group shapes (no kernel abstraction)
(None, GroupShape(1, 128)),
(None, GroupShape(1, 64)),
# Blockwise group shapes
(FlashInferFp8DeepGEMMDynamicBlockScaledKernel, GroupShape(1, 128)),
(CutlassFp8BlockScaledMMKernel, GroupShape(1, 128)),
(DeepGemmFp8BlockScaledMMKernel, GroupShape(1, 128)),
(TritonFp8BlockScaledMMKernel, GroupShape(1, 128)),
(TritonFp8BlockScaledMMKernel, GroupShape(1, 64)),
]

# ROCm kernels
Expand All @@ -80,8 +87,8 @@
# ChannelWiseTorchFP8ScaledMMLinearKernel only supports per-token
(ChannelWiseTorchFP8ScaledMMLinearKernel, GroupShape.PER_TOKEN),
# Blockwise group shapes
(None, GroupShape(1, 128)),
(None, GroupShape(1, 64)),
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you know what the equivalent of the `(None, GroupShape(1, 64))` test case is in this PR?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tjtanaa I added `(TritonFp8BlockScaledMMKernel, GroupShape(1, 64))` for ROCm, similar to the CUDA list.

(TritonFp8BlockScaledMMKernel, GroupShape(1, 128)),
(TritonFp8BlockScaledMMKernel, GroupShape(1, 64)),
]

KERNEL_GROUPSHAPE_COMBINATIONS = (
Expand All @@ -100,8 +107,8 @@
# Per-token with ChannelWiseTorchFP8ScaledMMLinearKernel
(ChannelWiseTorchFP8ScaledMMLinearKernel, GroupShape.PER_TOKEN, True),
(ChannelWiseTorchFP8ScaledMMLinearKernel, GroupShape.PER_TOKEN, False),
# Blockwise (no kernel abstraction)
(None, GroupShape(1, 128), True),
# Blockwise
(AiterFp8BlockScaledMMKernel, GroupShape(1, 128), True),
]


Expand All @@ -110,8 +117,9 @@ def __init__(
self,
hidden_size: int,
eps: float,
force_kernel: FP8ScaledMMLinearKernel | None,
force_kernel: type[_KernelT] | None,
group_shape: GroupShape,
dtype: torch.dtype,
use_aiter_fusion: bool = False,
use_aiter_quant: bool = False,
*args,
Expand All @@ -129,54 +137,42 @@ def __init__(
is_blockwise = group_shape.is_per_group()

if is_blockwise:
act_quant_scale_desc = ScaleDesc(torch.float32, False, group_shape)
self.activation_quant_key = QuantKey(
dtype=FP8_DTYPE, scale=act_quant_scale_desc, symmetric=True
block_size = group_shape.col
self.activation_quant_key = create_fp8_quant_key(
static=False, group_shape=group_shape
)
self.fp8_linear_layers = [
TestBlockFP8Layer(
weight_shape=(hidden_size, hidden_size),
group_shape=group_shape,
cutlass_block_fp8_supported=cutlass_block_fp8_supported(),
use_aiter_and_is_supported=use_aiter_quant,
transpose_weights=use_aiter_fusion,
)
for _ in range(3)
]

self.enable_quant_fp8_custom_op = (
False
if use_aiter_quant
else self.fp8_linear_layers[0].linear_op.input_quant_op.enabled()
self.weight_quant_key = create_fp8_quant_key(
static=True, group_shape=GroupShape(block_size, block_size)
)

else:
is_static = group_shape == GroupShape.PER_TENSOR
act_quant_scale_desc = ScaleDesc(torch.float32, is_static, group_shape)
w_quant_scale_desc = ScaleDesc(torch.float32, True, group_shape)
self.activation_quant_key = QuantKey(
dtype=FP8_DTYPE, scale=act_quant_scale_desc, symmetric=True
self.activation_quant_key = create_fp8_quant_key(
is_static, group_shape=group_shape
)
self.weight_quant_key = QuantKey(
dtype=FP8_DTYPE, scale=w_quant_scale_desc, symmetric=True
self.weight_quant_key = create_fp8_quant_key(
static=True, group_shape=group_shape
)
self.fp8_linear_layers = [
TestFP8Layer(
weight_shape=(hidden_size, hidden_size),
activation_quant_key=self.activation_quant_key,
weight_quant_key=self.weight_quant_key,
force_kernel=force_kernel,
)
for _ in range(3)
]

# Enable aiter quantization if requested
for layer in self.fp8_linear_layers:
layer.kernel.quant_fp8.use_aiter = use_aiter_quant
self.fp8_linear_layers = [
TestFP8Layer(
weight_shape=(hidden_size, hidden_size),
activation_quant_key=self.activation_quant_key,
weight_quant_key=self.weight_quant_key,
force_kernel=force_kernel,
transpose_weights=use_aiter_fusion,
input_dtype=dtype,
)
for _ in range(3)
]

# Enable aiter quantization if requested
for layer in self.fp8_linear_layers:
layer.kernel.quant_fp8.use_aiter = use_aiter_quant

self.enable_quant_fp8_custom_op = self.fp8_linear_layers[
0
].is_quant_fp8_enabled()
self.enable_quant_fp8_custom_op = self.fp8_linear_layers[
0
].is_quant_fp8_enabled()

def forward(self, x):
# avoid having graph input be an arg to a pattern directly
Expand Down Expand Up @@ -354,6 +350,7 @@ def test_fusion_rmsnorm_quant(
eps=eps,
force_kernel=force_kernel,
group_shape=group_shape,
dtype=dtype,
use_aiter_fusion=False,
use_aiter_quant=False,
)
Expand Down Expand Up @@ -426,6 +423,7 @@ def test_aiter_fusion_rmsnorm_quant(
eps=eps,
force_kernel=force_kernel,
group_shape=group_shape,
dtype=dtype,
use_aiter_fusion=True, # Always use aiter fusion ops in aiter test
use_aiter_quant=use_aiter_quant_op, # Toggle aiter quantization
)
Expand Down
2 changes: 2 additions & 0 deletions tests/compile/passes/test_fusion_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def __init__(
self.kv_cache_dtype = kv_cache_dtype
self.device = device
self.vllm_config = vllm_config
self.dtype = vllm_config.model_config.dtype

self.attn = Attention(
num_heads=self.num_qo_heads,
Expand Down Expand Up @@ -155,6 +156,7 @@ def __init__(self, *args, **kwargs):
activation_quant_key=self.quant_key,
weight_quant_key=self.quant_key,
device=self.device,
input_dtype=self.dtype,
)

w = kwargs.get("w")
Expand Down
2 changes: 2 additions & 0 deletions tests/compile/passes/test_mla_attn_quant_fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def __init__(
self.kv_cache_dtype = kv_cache_dtype
self.device = device
self.vllm_config = vllm_config
self.dtype = vllm_config.model_config.dtype

# Create kv_b_proj (ColumnParallelLinear) on device.
# Reuse weights from prior model instance when available, because
Expand Down Expand Up @@ -190,6 +191,7 @@ def __init__(self, *args, **kwargs):
activation_quant_key=self.quant_key,
weight_quant_key=self.quant_key,
device=self.device,
input_dtype=self.dtype,
)

w = kwargs.get("w")
Expand Down
Loading
Loading