121 changes: 121 additions & 0 deletions tests/compile/test_fusion2.py
@@ -0,0 +1,121 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest
import torch

import vllm.envs as envs
import vllm.plugins
from vllm.compilation.fusion import (FUSED_OPS, QUANT_OPS, FusedRMSQuantKey,
FusionPass, GroupShape, QuantKey)
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.config import (CompilationConfig, CompilationLevel, PassConfig,
VllmConfig)
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
CUTLASS_FP8_SUPPORTED, Fp8LinearOp, maybe_create_device_identity)
from vllm.platforms import current_platform

from .backend import TestBackend

FP8_DTYPE = current_platform.fp8_dtype()


class TestModel(torch.nn.Module):

def __init__(self, hidden_size: int, eps: float, static: bool,
cutlass_fp8_enabled: bool, *args, **kwargs):
super().__init__(*args, **kwargs)
self.cutlass_fp8_enabled = cutlass_fp8_enabled
self.norm = [RMSNorm(hidden_size, eps) for _ in range(3)]
self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(2)]
group_shape = GroupShape.PER_TENSOR if static else GroupShape.PER_TOKEN
self.key = QuantKey(dtype=FP8_DTYPE,
static=static,
group_shape=group_shape,
symmetric=True)
if static:
self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(2)]
else:
self.scale = [None for _ in range(2)]
self.w = [
torch.rand(hidden_size, hidden_size).to(dtype=FP8_DTYPE).t()
for _ in range(2)
]
self.fp8_linear = Fp8LinearOp(
cutlass_fp8_supported=cutlass_fp8_enabled,
act_quant_static=static,
act_quant_group_shape=group_shape,
)

def forward(self, x):
resid = torch.sqrt(x)

Check failure on line 52 in tests/compile/test_fusion2.py (GitHub Actions / pre-commit, Ruff F841):
tests/compile/test_fusion2.py:52:9: F841 Local variable `resid` is assigned to but never used
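# Note (an assumption, not part of the PR): the simplest way to address the F841
# warning above is to drop the unused `resid = torch.sqrt(x)` assignment, unless a
# later revision intends to pass it as a residual input to the norm.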
y = self.norm[0](x)

return self.fp8_linear.apply(y,
self.w[0],
self.wscale[0],
input_scale=self.scale[0])

def ops_in_model_before(self):
return [QUANT_OPS[self.key]]

def ops_in_model_after(self):
return [
FUSED_OPS[FusedRMSQuantKey(self.key, False)],
]


@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("hidden_size", [64, 3392, 4096])
@pytest.mark.parametrize("num_tokens", [7, 256, 533, 2048, 2049])
@pytest.mark.parametrize("eps", [1e-5, 1e-6])
@pytest.mark.parametrize("static", [True, False])
@pytest.mark.parametrize("cutlass_fp8_enabled",
[True, False] if CUTLASS_FP8_SUPPORTED else [False])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"],
reason="Only test on CUDA and ROCm")
def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static,
cutlass_fp8_enabled):
torch.set_default_device("cuda")
torch.set_default_dtype(dtype)
torch.manual_seed(1)
maybe_create_device_identity() # needed for certain non-cutlass fp8 paths

vllm_config = VllmConfig(compilation_config=CompilationConfig(
level=CompilationLevel.PIECEWISE,
custom_ops=["+rms_norm"],
pass_config=PassConfig(enable_fusion=True, enable_noop=True),
))
with vllm.config.set_current_vllm_config(vllm_config):
# Reshape pass is needed for the fusion pass to work
noop_pass = NoOpEliminationPass(vllm_config)
fusion_pass = FusionPass.instance(vllm_config)

backend = TestBackend(noop_pass, fusion_pass)
model = TestModel(hidden_size, eps, static, cutlass_fp8_enabled)

# First dimension dynamic
x = torch.rand(num_tokens, hidden_size)
torch._dynamo.mark_dynamic(x, 0)

result = model(x)

model2 = torch.compile(model, backend=backend)
result2 = model2(x)

# Higher tol for dynamic, even higher for bfloat16
if static:
ATOL, RTOL = (1e-3, 1e-3)
elif dtype == torch.float16:
ATOL, RTOL = (2e-3, 2e-3)
else:
ATOL, RTOL = (1e-2, 1e-2)

torch.testing.assert_close(result, result2, atol=ATOL, rtol=RTOL)

# In pre-nodes, fp8 quant should be there and fused kernels should not
# backend.check_before_ops(model.ops_in_model_before())
Contributor (severity: high):
This assertion is commented out. In test files, assertions should not be commented out as it can hide regressions or bugs. If this check is failing, it should be investigated and fixed. Please either re-enable this check or provide a justification for its removal.
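A minimal sketch of what re-enabling the check might look like, assuming the pre-pass graph still contains the standalone fp8 quant op returned by ops_in_model_before():

    # In pre-nodes, fp8 quant should be there and fused kernels should not
    backend.check_before_ops(model.ops_in_model_before())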


# In post-nodes, fused kernels should be there and fp8 quant should not
backend.check_after_ops(model.ops_in_model_after())
67 changes: 38 additions & 29 deletions tests/compile/test_fusion_all_reduce.py
@@ -8,9 +8,11 @@
import vllm.envs as envs
from vllm.compilation.collective_fusion import AllReduceFusionPass
from vllm.compilation.fix_functionalization import FixFunctionalizationPass
from vllm.compilation.fx_utils import find_op_nodes
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.config import (CompilationConfig, CompilationLevel, DeviceConfig,
ModelConfig, PassConfig, VllmConfig)
ModelConfig, PassConfig, VllmConfig,
set_current_vllm_config)
from vllm.distributed import tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import (init_distributed_environment,
initialize_model_parallel)
@@ -76,25 +78,23 @@ def __init__(self, hidden_size=16, token_num=16, eps=1e-6):
self.quant_fp8 = QuantFP8(static=True,
group_shape=GroupShape.PER_TENSOR)
self.scale = torch.rand(1, dtype=torch.float32)
self.output = torch.empty((token_num, hidden_size),
dtype=torch.float32)
# self.output = torch.empty((token_num, hidden_size),
# dtype=torch.float32)

def forward(self, hidden_states, residual):
view = hidden_states.reshape(-1, self.hidden_size)
all_reduce = tensor_model_parallel_all_reduce(view)
norm_output, residual_output = self.norm(all_reduce, residual)
torch.ops._C.static_scaled_fp8_quant(self.output,
norm_output.contiguous(),
self.scale)
return self.output, residual_output
quant_out, _ = self.quant_fp8(norm_output, scale=self.scale)
return quant_out, residual_output

def ops_in_model_after(self):
return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default]

def ops_in_model_before(self):
return [
torch.ops.vllm.all_reduce.default,
torch.ops._C.static_scaled_fp8_quant.default
# torch.ops._C.static_scaled_fp8_quant.default
]


@@ -198,8 +198,7 @@ def all_reduce_fusion_pass_on_test_model(local_rank: int, world_size: int,
initialize_model_parallel(tensor_model_parallel_size=world_size)

vllm_config = VllmConfig(compilation_config=CompilationConfig(
level=CompilationLevel.PIECEWISE,
custom_ops=["+rms_norm", "+quant_fp8"]))
level=CompilationLevel.PIECEWISE, custom_ops=["+rms_norm"]))
vllm_config.compilation_config.pass_config = PassConfig(
enable_fi_allreduce_fusion=True, enable_noop=True)
vllm_config.device_config = DeviceConfig(device=torch.device("cuda"))
@@ -211,22 +210,32 @@ def all_reduce_fusion_pass_on_test_model(local_rank: int, world_size: int,
trust_remote_code=True,
dtype=dtype,
seed=42)

all_reduce_fusion_pass = AllReduceFusionPass(vllm_config)
noop_pass = NoOpEliminationPass(vllm_config)
func_pass = FixFunctionalizationPass(vllm_config)

backend = TestBackend(all_reduce_fusion_pass, noop_pass, func_pass)

token_num = batch_size * seq_len
model = test_model_cls(hidden_size, token_num)

hidden_states = torch.randn((token_num, hidden_size), requires_grad=False)
residual = torch.randn((token_num, hidden_size), requires_grad=False)

compiled_model = torch.compile(model, backend=backend)
compiled_model(hidden_states, residual)

backend.check_before_ops(model.ops_in_model_before(), fully_replaced=False)
backend.check_after_ops(model.ops_in_model_after())
del all_reduce_fusion_pass
with set_current_vllm_config(vllm_config):
all_reduce_fusion_pass = AllReduceFusionPass(vllm_config)
noop_pass = NoOpEliminationPass(vllm_config)
func_pass = FixFunctionalizationPass(vllm_config)

backend = TestBackend(all_reduce_fusion_pass, noop_pass, func_pass)

token_num = batch_size * seq_len
model = test_model_cls(hidden_size, token_num)

hidden_states = torch.randn((token_num, hidden_size),
requires_grad=False)
residual = torch.randn((token_num, hidden_size), requires_grad=False)

compiled_model = torch.compile(model, backend=backend)
compiled_model(hidden_states, residual)

backend.check_before_ops(model.ops_in_model_before(),
fully_replaced=False)
backend.check_after_ops(model.ops_in_model_after())
print(backend.graph_pre_pass)
print(backend.graph_post_pass)
for node in find_op_nodes(
torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default,
backend.graph_post_pass):
print(f"{node.args=}")
print(f"{node.kwargs=}")

Comment on lines +233 to +240 (Contributor, severity: high):
These print statements and the following loop appear to be debugging code. They should be removed before merging to keep the test output clean.

del all_reduce_fusion_pass
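A minimal sketch of the tail of the test with the debugging output removed, as the reviewer suggests (keeping only the existing checks and cleanup):

    backend.check_before_ops(model.ops_in_model_before(),
                             fully_replaced=False)
    backend.check_after_ops(model.ops_in_model_after())
    del all_reduce_fusion_pass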