Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
43 commits
Select commit Hold shift + click to select a range
974e682
first try
vllmellm Oct 28, 2025
e54e572
fix int8 path
vllmellm Oct 30, 2025
c05027f
clean up; fix quark path
vllmellm Oct 30, 2025
c089ea5
update quark fp8 path; format
vllmellm Oct 30, 2025
38825fc
Merge branch 'main' into refactor-fp8-linear
vllmellm Oct 31, 2025
423e2a6
reduce logging boilerplate; update fp8 path
vllmellm Oct 31, 2025
dd00106
reduce kernel init boilerplate
vllmellm Oct 31, 2025
7d36148
update ptpc path; bug fixes
vllmellm Oct 31, 2025
1f65cd5
revert input scale upper bounds
vllmellm Oct 31, 2025
5fbe76b
format; update fbgemm path
vllmellm Oct 31, 2025
e845035
bug fix
vllmellm Oct 31, 2025
d92c23b
fix types; reduce boilerplate for int8
vllmellm Nov 1, 2025
8e8218e
Merge remote-tracking branch 'origin/main' into refactor-fp8-linear
vllmellm Nov 1, 2025
4ce0ba2
format
vllmellm Nov 1, 2025
dd5a70e
update unit tests to use ScaledMMLinearKernels
vllmellm Nov 1, 2025
52ff537
update modelopt path
vllmellm Nov 1, 2025
b13c4bb
remove FP8LinearOps
vllmellm Nov 1, 2025
7794009
add missing arg
vllmellm Nov 3, 2025
a8010c7
flash_infer missing out dtype bug fix
vllmellm Nov 3, 2025
f5e6cd9
prefer QuantKey over ScaledMMLinearQuantStrategy
vllmellm Nov 4, 2025
a76f7bb
rename flash_infer.py to flashinfer.py
vllmellm Nov 4, 2025
f10171c
correct minimum capability req for channelwise torch
vllmellm Nov 4, 2025
fb72ec8
add missing kernels for cuda dispatch
vllmellm Nov 4, 2025
93fb707
reduce test boilerplate
vllmellm Nov 4, 2025
abf597e
fix quant key selection for ct; remove register_paramter calls; format
vllmellm Nov 4, 2025
aaa0d55
format
vllmellm Nov 4, 2025
7fb4657
implement apply func in base FP8ScaledMMLinearKernel class
vllmellm Nov 7, 2025
56a05cd
add minimal documentation for torch scaled mm base class
vllmellm Nov 7, 2025
9ff9b44
use for loops for fp8 linear layers init in tests
vllmellm Nov 7, 2025
cfb476f
minor fixes
vllmellm Nov 7, 2025
e47d55b
force kernels for tests
vllmellm Nov 7, 2025
edb6d43
ensure static scales for ChannelWiseTorchScaledMMLinearKernel; remove…
vllmellm Nov 7, 2025
45a3008
feat: Integrate AITER bpreshuffle and ck operators on top of fp8 refa…
vllmellm Nov 10, 2025
858765f
fix output padding for torch _scaled_mm
vllmellm Nov 10, 2025
45803e1
Merge remote-tracking branch 'origin/refactor-fp8-linear' into refact…
vllmellm Nov 10, 2025
4c596a0
Merge remote-tracking branch 'origin/main' into refactor-fp8-linear
vllmellm Nov 11, 2025
65ecf48
optional input scales
vllmellm Nov 12, 2025
405d280
remove maybe_create_device_identity
vllmellm Nov 12, 2025
9784a0c
Merge remote-tracking branch 'origin/main' into refactor-fp8-linear
vllmellm Nov 13, 2025
10eebd4
add CPU kernels; fix fp8 quant type selection
vllmellm Nov 14, 2025
686b3ec
Merge remote-tracking branch 'origin/refactor-fp8-linear' into refact…
vllmellm Nov 17, 2025
679a7cf
WIP: Integrate Aiter bpreshuffle and ck kernels
vllmellm Nov 17, 2025
30624ea
sync upstream
vllmellm Dec 2, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 23 additions & 22 deletions tests/compile/distributed/test_fusion_all_reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,14 @@
initialize_model_parallel,
)
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
Fp8LinearOp,
GroupShape,
from vllm.model_executor.layers.quantization.utils.quant_utils import (
kFp8StaticTensorSym,
)
from vllm.platforms import current_platform
from vllm.utils.system_utils import update_environment_variables

from ...utils import has_module_attribute, multi_gpu_test
from ..backend import TestBackend
from ..utils import TestFP8Layer, has_module_attribute, multi_gpu_test
from .backend import TestBackend


class TestAllReduceRMSNormModel(torch.nn.Module):
Expand Down Expand Up @@ -75,49 +74,51 @@ def ops_in_model_after(self):


class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
quant_key = kFp8StaticTensorSym

def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
super().__init__()
self.hidden_size = hidden_size
self.eps = eps
self.norm = [RMSNorm(hidden_size, eps) for i in range(4)]
self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
self.w = [
self.input_scale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
self.weight = [
torch.rand(hidden_size, hidden_size)
.to(dtype=current_platform.fp8_dtype())
.t()
for _ in range(3)
]

self.fp8_linear = Fp8LinearOp(
act_quant_static=True,
act_quant_group_shape=GroupShape.PER_TENSOR,
)

self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
self.fp8_linear_layers = [
TestFP8Layer(
self.quant_key,
self.quant_key,
self.weight[i],
self.wscale[i],
input_scale=self.input_scale[i],
)
for i in range(3)
]

def forward(self, hidden_states):
# avoid having graph input be an arg to a pattern directly
z = torch.relu(hidden_states)
x = resid = tensor_model_parallel_all_reduce(z)
y = self.norm[0](x)

z2 = self.fp8_linear.apply(
y, self.w[0], self.wscale[0], input_scale=self.scale[0]
)
z2 = self.fp8_linear_layers[0](y)

x2 = tensor_model_parallel_all_reduce(z2)
y2, resid = self.norm[1](x2, resid)

z3 = self.fp8_linear.apply(
y2, self.w[1], self.wscale[1], input_scale=self.scale[1]
)
z3 = self.fp8_linear_layers[1](y2)

x3 = tensor_model_parallel_all_reduce(z3)
y3, resid = self.norm[2](x3, resid) # use resid here

z4 = self.fp8_linear.apply(
y3, self.w[2], self.wscale[2], input_scale=self.scale[2]
)
z4 = self.fp8_linear_layers[2](y3)

x4 = tensor_model_parallel_all_reduce(z4)
y4, resid = self.norm[3](x4, resid) # use resid here
return y4
Expand All @@ -129,7 +130,7 @@ def ops_in_model_before(self):
return [
torch.ops.vllm.all_reduce.default,
torch.ops._C.static_scaled_fp8_quant.default
if self.fp8_linear.quant_fp8.enabled()
if self.fp8_linear_layers[0].is_quant_fp8_enabled()
else torch.ops.aten.reciprocal.default,
]

Expand Down
40 changes: 18 additions & 22 deletions tests/compile/distributed/test_sequence_parallelism.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,14 @@
initialize_model_parallel,
)
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp
from vllm.model_executor.layers.quantization.utils.quant_utils import (
kFp8StaticTensorSym,
)
from vllm.platforms import current_platform
from vllm.utils.system_utils import update_environment_variables

from ...utils import multi_gpu_test
from ..backend import TestBackend
from ..utils import TestFP8Layer, multi_gpu_test
from .backend import TestBackend

FP8_DTYPE = current_platform.fp8_dtype()
prompts = [
Expand Down Expand Up @@ -93,6 +94,8 @@ def ops_in_model(self):


class TestAllReduceRMSNormStaticQuantFP8Model(torch.nn.Module):
quant_key = kFp8StaticTensorSym

def __init__(self, hidden_size=16, eps=1e-6):
super().__init__()
self.vllm_config = get_current_vllm_config()
Expand All @@ -101,42 +104,35 @@ def __init__(self, hidden_size=16, eps=1e-6):
self.norm = [RMSNorm(hidden_size, eps) for i in range(4)]
self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]
self.w = [
torch.rand(hidden_size, hidden_size)
.to(dtype=current_platform.fp8_dtype())
.t()
torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
for _ in range(3)
]

self.fp8_linear = Fp8LinearOp(
act_quant_static=True,
act_quant_group_shape=GroupShape.PER_TENSOR,
)

self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(3)]

self.fp8_linears = [
TestFP8Layer(
self.quant_key, self.quant_key, self.w[i], self.wscale[i], self.scale[i]
)
for i in range(3)
]

def forward(self, hidden_states):
# avoid having graph input be an arg to a pattern directly
z = torch.relu(hidden_states)
x = resid = tensor_model_parallel_all_reduce(z)
y = self.norm[0](x)

z2 = self.fp8_linear.apply(
y, self.w[0], self.wscale[0], input_scale=self.scale[0]
)
z2 = self.fp8_linears[0](y)

x2 = tensor_model_parallel_all_reduce(z2)
y2, resid = self.norm[1](x2, resid)

z3 = self.fp8_linear.apply(
y2, self.w[1], self.wscale[1], input_scale=self.scale[1]
)
z3 = self.fp8_linears[1](y2)

x3 = tensor_model_parallel_all_reduce(z3)
y3, resid = self.norm[2](x3, resid) # use resid here

z4 = self.fp8_linear.apply(
y3, self.w[2], self.wscale[2], input_scale=self.scale[2]
)
z4 = self.fp8_linears[2].apply(y3)
x4 = tensor_model_parallel_all_reduce(z4)
y4, resid = self.norm[3](x4, resid) # use resid here
return y4
Expand Down
54 changes: 32 additions & 22 deletions tests/compile/test_functionalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,36 +20,41 @@
)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
from vllm.model_executor.layers.quantization.utils.w8a8_utils import Fp8LinearOp
from vllm.model_executor.layers.quantization.utils.quant_utils import (
kFp8StaticTensorSym,
)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.platforms import current_platform

from ..utils import TestFP8Layer
from .backend import TestBackend

TEST_FP8 = current_platform.supports_fp8()
FP8_DTYPE = current_platform.fp8_dtype()


class TestSiluMul(torch.nn.Module):
quant_key = kFp8StaticTensorSym

def __init__(self, hidden_size: int = 128):
super().__init__()
self.silu_and_mul = SiluAndMul()
self.wscale = torch.rand(1, dtype=torch.float32)
self.scale = torch.rand(1, dtype=torch.float32)

self.weight_scale = torch.rand(1, dtype=torch.float32)
self.input_scale = torch.rand(1, dtype=torch.float32)
if TEST_FP8:
self.w = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
self.fp8_linear = Fp8LinearOp(
act_quant_static=True,
act_quant_group_shape=GroupShape.PER_TENSOR,
self.weight = torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
self.fp8_linear = TestFP8Layer(
self.quant_key,
self.quant_key,
self.weight,
self.weight_scale,
self.input_scale,
)

def forward(self, x):
y = self.silu_and_mul(x)
if TEST_FP8:
x2 = self.fp8_linear.apply(y, self.w, self.wscale, input_scale=self.wscale)
return x2
return self.fp8_linear(y)
else:
return y

Expand All @@ -67,6 +72,8 @@ def ops_not_in_model(self):


class TestFusedAddRMSNorm(torch.nn.Module):
quant_key = kFp8StaticTensorSym

def __init__(self, hidden_size=16, intermediate_size=32):
super().__init__()
self.hidden_size = hidden_size
Expand All @@ -81,11 +88,18 @@ def __init__(self, hidden_size=16, intermediate_size=32):
torch.nn.init.normal_(self.gate_proj, std=0.02)

if TEST_FP8:
self.fp8_linear = Fp8LinearOp(act_quant_static=True)

self.scale = torch.rand(1, dtype=torch.float32)
self.w = torch.rand(hidden_size, intermediate_size).to(dtype=FP8_DTYPE).t()
self.wscale = torch.rand(1, dtype=torch.float32)
self.weight = (
torch.rand(hidden_size, intermediate_size).to(dtype=FP8_DTYPE).t()
)
self.weight_scale = torch.rand(1, dtype=torch.float32)
self.input_scale = torch.rand(1, dtype=torch.float32)
self.fp8_linear = TestFP8Layer(
self.quant_key,
self.quant_key,
self.weight,
self.weight_scale,
self.input_scale,
)

def forward(self, hidden_states, residual):
# Reshape input
Expand All @@ -99,13 +113,9 @@ def forward(self, hidden_states, residual):
norm_output, residual_output = self.norm(mm, residual)

if TEST_FP8:
self.input_scale = self.input_scale.to(norm_output.device)
# scaled_mm with static input quantization
fp8_linear_result = self.fp8_linear.apply(
norm_output,
self.w,
self.wscale,
input_scale=self.scale.to(norm_output.device),
)
fp8_linear_result = self.fp8_linear(norm_output)

return fp8_linear_result, residual_output

Expand Down
Loading