Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
93 commits
Select commit Hold shift + click to select a range
974e682
first try
vllmellm Oct 28, 2025
e54e572
fix int8 path
vllmellm Oct 30, 2025
c05027f
clean up; fix quark path
vllmellm Oct 30, 2025
c089ea5
update quark fp8 path; format
vllmellm Oct 30, 2025
38825fc
Merge branch 'main' into refactor-fp8-linear
vllmellm Oct 31, 2025
423e2a6
reduce logging boilerplate; update fp8 path
vllmellm Oct 31, 2025
dd00106
reduce kernel init boilerplate
vllmellm Oct 31, 2025
7d36148
update ptpc path; bug fixes
vllmellm Oct 31, 2025
1f65cd5
revert input scale upper bounds
vllmellm Oct 31, 2025
5fbe76b
format; update fbgemm path
vllmellm Oct 31, 2025
e845035
bug fix
vllmellm Oct 31, 2025
d92c23b
fix types; reduce boilerplate for int8
vllmellm Nov 1, 2025
8e8218e
Merge remote-tracking branch 'origin/main' into refactor-fp8-linear
vllmellm Nov 1, 2025
4ce0ba2
format
vllmellm Nov 1, 2025
dd5a70e
update unit tests to use ScaledMMLinearKernels
vllmellm Nov 1, 2025
52ff537
update modelopt path
vllmellm Nov 1, 2025
b13c4bb
remove FP8LinearOps
vllmellm Nov 1, 2025
7794009
add missing arg
vllmellm Nov 3, 2025
a8010c7
flash_infer missing out dtype bug fix
vllmellm Nov 3, 2025
f5e6cd9
prefer QuantKey over ScaledMMLinearQuantStrategy
vllmellm Nov 4, 2025
a76f7bb
rename flash_infer.py to flashinfer.py
vllmellm Nov 4, 2025
f10171c
correct minimum capability req for channelwise torch
vllmellm Nov 4, 2025
fb72ec8
add missing kernels for cuda dispatch
vllmellm Nov 4, 2025
93fb707
reduce test boilerplate
vllmellm Nov 4, 2025
abf597e
fix quant key selection for ct; remove register_parameter calls; format
vllmellm Nov 4, 2025
aaa0d55
format
vllmellm Nov 4, 2025
7fb4657
implement apply func in base FP8ScaledMMLinearKernel class
vllmellm Nov 7, 2025
56a05cd
add minimal documentation for torch scaled mm base class
vllmellm Nov 7, 2025
9ff9b44
use for loops for fp8 linear layers init in tests
vllmellm Nov 7, 2025
cfb476f
minor fixes
vllmellm Nov 7, 2025
e47d55b
force kernels for tests
vllmellm Nov 7, 2025
edb6d43
ensure static scales for ChannelWiseTorchScaledMMLinearKernel; remove…
vllmellm Nov 7, 2025
858765f
fix output padding for torch _scaled_mm
vllmellm Nov 10, 2025
4c596a0
Merge remote-tracking branch 'origin/main' into refactor-fp8-linear
vllmellm Nov 11, 2025
65ecf48
optional input scales
vllmellm Nov 12, 2025
405d280
remove maybe_create_device_identity
vllmellm Nov 12, 2025
9784a0c
Merge remote-tracking branch 'origin/main' into refactor-fp8-linear
vllmellm Nov 13, 2025
10eebd4
add CPU kernels; fix fp8 quant type selection
vllmellm Nov 14, 2025
cbfcff3
add dynamic per tensor fallback
vllmellm Nov 18, 2025
9d94f5d
Merge remote-tracking branch 'origin/main' into refactor-fp8-linear
vllmellm Nov 18, 2025
1c5f633
format
vllmellm Nov 18, 2025
231f442
Merge remote-tracking branch 'origin/main' into refactor-fp8-linear
vllmellm Nov 24, 2025
d80e0f9
Merge branch 'main' into refactor-fp8-linear
vllmellm Dec 3, 2025
8937e96
Merge remote-tracking branch 'origin/main' into refactor-fp8-linear
vllmellm Dec 10, 2025
4e488da
fix merge artifacts
vllmellm Dec 10, 2025
52e2a31
Merge remote-tracking branch 'origin/main' into refactor-fp8-linear
vllmellm Dec 11, 2025
26fbb46
Merge branch 'main' into refactor-fp8-linear
vllmellm Dec 11, 2025
34e305e
revert group shape import
vllmellm Dec 11, 2025
fb78f30
Merge remote-tracking branch 'origin/refactor-fp8-linear' into refact…
vllmellm Dec 11, 2025
a165967
Merge remote-tracking branch 'origin/main' into HEAD
vllmellm Dec 24, 2025
d3fc072
format
vllmellm Dec 24, 2025
fcf9b4c
Merge branch 'main' into refactor-fp8-linear
tjtanaa Jan 5, 2026
f5e24a4
remove get_min_capability
vllmellm Jan 6, 2026
c20f0b6
simplify test fp8 layer
vllmellm Jan 6, 2026
27878c5
simplify fp8 scaled mm
vllmellm Jan 7, 2026
06935f0
fix can implement and is supported assertions
vllmellm Jan 7, 2026
6668ffb
correct type for Bs
vllmellm Jan 7, 2026
686d0d1
Merge remote-tracking branch 'origin/main' into HEAD
vllmellm Jan 7, 2026
da53e6d
remove assym models from amd correctness tests
vllmellm Jan 8, 2026
ecee3d9
Merge remote-tracking branch 'origin/main' into HEAD
vllmellm Jan 8, 2026
d7e6e3b
revert mm model test config
vllmellm Jan 8, 2026
f4277fb
Merge branch 'main' into refactor-fp8-linear
vllmellm Jan 8, 2026
d749132
default to None for get_layer_params
vllmellm Jan 8, 2026
14b4d53
Merge remote-tracking branch 'origin/refactor-fp8-linear' into HEAD
vllmellm Jan 8, 2026
fc6c966
default to None only for optional params
vllmellm Jan 8, 2026
52c7ec8
revert small model config for gsm8k tests
vllmellm Jan 8, 2026
d8ee1b1
Merge branch 'main' into refactor-fp8-linear
tjtanaa Jan 9, 2026
8f0039a
minor fixes
vllmellm Jan 9, 2026
c52cfae
Merge remote-tracking branch 'origin/refactor-fp8-linear' into HEAD
vllmellm Jan 9, 2026
25d0d20
Merge remote-tracking branch 'origin/main' into HEAD
vllmellm Jan 12, 2026
d691155
minor fixes
vllmellm Jan 12, 2026
b97a08f
minor fixes
vllmellm Jan 12, 2026
6ce94db
format
vllmellm Jan 12, 2026
d4ffc95
reshape without padding
vllmellm Jan 13, 2026
a897ac6
minor changes; add TODOs
vllmellm Jan 13, 2026
6df3d61
consistent names
vllmellm Jan 13, 2026
81814c7
simplify act quant selection
vllmellm Jan 13, 2026
cd25901
link todo to github issue
vllmellm Jan 13, 2026
12296fe
remove skipped tests
vllmellm Jan 13, 2026
a2cbb58
Merge branch 'main' into refactor-fp8-linear
vllmellm Jan 14, 2026
cec9bf6
Merge branch 'main' into refactor-fp8-linear
vllmellm Jan 14, 2026
96508e5
fix CUDA unit tests
vllmellm Jan 14, 2026
0b40e43
Merge branch 'main' into refactor-fp8-linear
vllmellm Jan 14, 2026
e09f68e
Merge branch 'main' into refactor-fp8-linear
tjtanaa Jan 15, 2026
bdaaad4
fix cuda ci
vllmellm Jan 15, 2026
7b3af09
force cutlass kernel for attention fusion test
vllmellm Jan 15, 2026
9515da3
Merge branch 'main' into refactor-fp8-linear
tjtanaa Jan 15, 2026
6cf65dc
fix cuda ci
vllmellm Jan 15, 2026
b5c1c70
Merge branch 'main' into refactor-fp8-linear
tjtanaa Jan 16, 2026
1f95ab9
fix cutlass kernel hang
vllmellm Jan 16, 2026
ce381f0
Merge branch 'main' into refactor-fp8-linear
vllmellm Jan 17, 2026
99a5218
add comment
vllmellm Jan 19, 2026
8bf199f
Merge branch 'main' into refactor-fp8-linear
vllmellm Jan 20, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .buildkite/lm-eval-harness/configs/models-small-rocm.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
Qwen2.5-1.5B-Instruct.yaml
Meta-Llama-3.2-1B-Instruct-INT8-compressed-tensors.yaml
Meta-Llama-3-8B-Instruct-nonuniform-compressed-tensors.yaml
Qwen2.5-VL-3B-Instruct-FP8-dynamic.yaml
Qwen1.5-MoE-W4A16-compressed-tensors.yaml
44 changes: 17 additions & 27 deletions tests/compile/distributed/test_fusion_all_reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,14 @@
initialize_model_parallel,
)
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
Fp8LinearOp,
GroupShape,
from vllm.model_executor.layers.quantization.utils.quant_utils import (
kFp8StaticTensorSym,
)
from vllm.platforms import current_platform
from vllm.utils.system_utils import update_environment_variables
from vllm.utils.torch_utils import set_random_seed

from ...utils import has_module_attribute, multi_gpu_test
from ...utils import TestFP8Layer, has_module_attribute, multi_gpu_test
from ..backend import TestBackend


Expand Down Expand Up @@ -76,49 +75,40 @@ def ops_in_model_after(self):


class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
quant_key = kFp8StaticTensorSym

def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
super().__init__()
self.hidden_size = hidden_size
self.eps = eps
self.norm = [RMSNorm(hidden_size, eps) for i in range(4)]
self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
self.w = [
torch.rand(hidden_size, hidden_size)
.to(dtype=current_platform.fp8_dtype())
.t()
for _ in range(3)
self.fp8_linear_layers = [
TestFP8Layer(
weight_shape=(hidden_size, hidden_size),
activation_quant_key=self.quant_key,
weight_quant_key=self.quant_key,
)
for i in range(3)
]

self.fp8_linear = Fp8LinearOp(
act_quant_static=True,
act_quant_group_shape=GroupShape.PER_TENSOR,
)

self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]

def forward(self, hidden_states):
# avoid having graph input be an arg to a pattern directly
z = torch.relu(hidden_states)
x = resid = tensor_model_parallel_all_reduce(z)
y = self.norm[0](x)

z2 = self.fp8_linear.apply(
y, self.w[0], self.wscale[0], input_scale=self.scale[0]
)
z2 = self.fp8_linear_layers[0](y)

x2 = tensor_model_parallel_all_reduce(z2)
y2, resid = self.norm[1](x2, resid)

z3 = self.fp8_linear.apply(
y2, self.w[1], self.wscale[1], input_scale=self.scale[1]
)
z3 = self.fp8_linear_layers[1](y2)

x3 = tensor_model_parallel_all_reduce(z3)
y3, resid = self.norm[2](x3, resid) # use resid here

z4 = self.fp8_linear.apply(
y3, self.w[2], self.wscale[2], input_scale=self.scale[2]
)
z4 = self.fp8_linear_layers[2](y3)

x4 = tensor_model_parallel_all_reduce(z4)
y4, resid = self.norm[3](x4, resid) # use resid here
return y4
Expand All @@ -130,7 +120,7 @@ def ops_in_model_before(self):
return [
torch.ops.vllm.all_reduce.default,
torch.ops._C.static_scaled_fp8_quant.default
if self.fp8_linear.quant_fp8.enabled()
if self.fp8_linear_layers[0].is_quant_fp8_enabled()
else torch.ops.aten.reciprocal.default,
]

Expand Down
43 changes: 17 additions & 26 deletions tests/compile/distributed/test_sequence_parallelism.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,14 @@
initialize_model_parallel,
)
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp
from vllm.model_executor.layers.quantization.utils.quant_utils import (
kFp8StaticTensorSym,
)
from vllm.platforms import current_platform
from vllm.utils.system_utils import update_environment_variables
from vllm.utils.torch_utils import set_random_seed

from ...utils import multi_gpu_test
from ...utils import TestFP8Layer, multi_gpu_test
from ..backend import TestBackend

FP8_DTYPE = current_platform.fp8_dtype()
Expand Down Expand Up @@ -94,50 +95,40 @@ def ops_in_model(self):


class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
quant_key = kFp8StaticTensorSym

def __init__(self, hidden_size=16, eps=1e-6):
super().__init__()
self.vllm_config = get_current_vllm_config()
self.hidden_size = hidden_size
self.eps = eps
self.norm = [RMSNorm(hidden_size, eps) for i in range(4)]
self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
self.w = [
torch.rand(hidden_size, hidden_size)
.to(dtype=current_platform.fp8_dtype())
.t()
for _ in range(3)
self.fp8_linear_layers = [
TestFP8Layer(
weight_shape=(hidden_size, hidden_size),
activation_quant_key=self.quant_key,
weight_quant_key=self.quant_key,
)
for i in range(3)
]

self.fp8_linear = Fp8LinearOp(
act_quant_static=True,
act_quant_group_shape=GroupShape.PER_TENSOR,
)

self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]

def forward(self, hidden_states):
# avoid having graph input be an arg to a pattern directly
z = torch.relu(hidden_states)
x = resid = tensor_model_parallel_all_reduce(z)
y = self.norm[0](x)

z2 = self.fp8_linear.apply(
y, self.w[0], self.wscale[0], input_scale=self.scale[0]
)
z2 = self.fp8_linear_layers[0](y)

x2 = tensor_model_parallel_all_reduce(z2)
y2, resid = self.norm[1](x2, resid)

z3 = self.fp8_linear.apply(
y2, self.w[1], self.wscale[1], input_scale=self.scale[1]
)
z3 = self.fp8_linear_layers[1](y2)

x3 = tensor_model_parallel_all_reduce(z3)
y3, resid = self.norm[2](x3, resid) # use resid here

z4 = self.fp8_linear.apply(
y3, self.w[2], self.wscale[2], input_scale=self.scale[2]
)
z4 = self.fp8_linear_layers[2](y3)
x4 = tensor_model_parallel_all_reduce(z4)
y4, resid = self.norm[3](x4, resid) # use resid here
return y4
Expand All @@ -160,7 +151,7 @@ def ops_in_model(self):
return [
torch.ops._C.fused_add_rms_norm.default,
]
elif self.fp8_linear.quant_fp8.enabled():
elif any(layer.is_quant_fp8_enabled() for layer in self.fp8_linear_layers):
return [
torch.ops._C.static_scaled_fp8_quant.default,
]
Expand Down
41 changes: 19 additions & 22 deletions tests/compile/test_functionalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,36 +20,36 @@
)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp
from vllm.model_executor.layers.quantization.utils.quant_utils import (
kFp8StaticTensorSym,
)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.platforms import current_platform

from ..utils import TestFP8Layer
from .backend import TestBackend

TEST_FP8 = current_platform.supports_fp8()
FP8_DTYPE = current_platform.fp8_dtype()


class TestSiluMul(torch.nn.Module):
quant_key = kFp8StaticTensorSym

def __init__(self, hidden_size: int = 128):
super().__init__()
self.silu_and_mul = SiluAndMul()
self.wscale = torch.rand(1, dtype=torch.float32)
self.scale = torch.rand(1, dtype=torch.float32)

if TEST_FP8:
self.w = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
self.fp8_linear = Fp8LinearOp(
act_quant_static=True,
act_quant_group_shape=GroupShape.PER_TENSOR,
self.fp8_linear = TestFP8Layer(
weight_shape=(hidden_size, hidden_size),
activation_quant_key=self.quant_key,
weight_quant_key=self.quant_key,
)

def forward(self, x):
y = self.silu_and_mul(x)
if TEST_FP8:
x2 = self.fp8_linear.apply(y, self.w, self.wscale, input_scale=self.wscale)
return x2
return self.fp8_linear(y)
else:
return y

Expand All @@ -67,6 +67,8 @@ def ops_not_in_model(self):


class TestFusedAddRMSNorm(torch.nn.Module):
quant_key = kFp8StaticTensorSym

def __init__(self, hidden_size=16, intermediate_size=32):
super().__init__()
self.hidden_size = hidden_size
Expand All @@ -81,11 +83,11 @@ def __init__(self, hidden_size=16, intermediate_size=32):
torch.nn.init.normal_(self.gate_proj, std=0.02)

if TEST_FP8:
self.fp8_linear = Fp8LinearOp(act_quant_static=True)

self.scale = torch.rand(1, dtype=torch.float32)
self.w = torch.rand(hidden_size, intermediate_size).to(dtype=FP8_DTYPE).t()
self.wscale = torch.rand(1, dtype=torch.float32)
self.fp8_linear = TestFP8Layer(
weight_shape=(hidden_size, intermediate_size),
activation_quant_key=self.quant_key,
weight_quant_key=self.quant_key,
)

def forward(self, hidden_states, residual):
# Reshape input
Expand All @@ -100,12 +102,7 @@ def forward(self, hidden_states, residual):

if TEST_FP8:
# scaled_mm with static input quantization
fp8_linear_result = self.fp8_linear.apply(
norm_output,
self.w,
self.wscale,
input_scale=self.scale.to(norm_output.device),
)
fp8_linear_result = self.fp8_linear(norm_output)

return fp8_linear_result, residual_output

Expand Down
Loading