Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .buildkite/test-pipeline.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,7 @@ steps:
- pytest -v -s compile/test_pass_manager.py
- pytest -v -s compile/test_fusion.py
- pytest -v -s compile/test_sequence_parallelism.py
- pytest -v -s compile/test_async_tp.py

- label: PyTorch Fullgraph Smoke Test # 9min
mirror_hardwares: [amdexperimental, amdproduction]
Expand Down
18 changes: 18 additions & 0 deletions tests/compile/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

from torch import fx

from vllm.compilation.fx_utils import (find_specified_fn,
find_specified_fn_maybe)
from vllm.compilation.inductor_pass import InductorPass
from vllm.config import get_current_vllm_config

Expand Down Expand Up @@ -44,3 +46,19 @@ def post_pass(self, graph: fx.Graph):
self.graph_post_pass = deepcopy(graph)
# assign by reference, will reflect the final state of the graph
self.final_graph = graph

def check_before_ops(self, ops,
find_fn=find_specified_fn, \
find_fn_maybe=find_specified_fn_maybe, \
ops_fully_replaced=True):
for op in ops:
find_fn(self.graph_pre_pass.nodes, op)
if ops_fully_replaced:
assert find_fn_maybe(self.graph_post_pass.nodes, op) is None

def check_after_ops(self, ops,
find_fn=find_specified_fn, \
find_fn_maybe=find_specified_fn_maybe):
for op in ops:
find_fn(self.graph_post_pass.nodes, op)
assert find_fn_maybe(self.graph_pre_pass.nodes, op) is None
253 changes: 253 additions & 0 deletions tests/compile/test_async_tp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,253 @@
# SPDX-License-Identifier: Apache-2.0

import json

import pytest
import torch

import vllm.envs as envs
from vllm.compilation.collective_fusion import AsyncTPPass
from vllm.config import (CompilationConfig, DeviceConfig, ModelConfig,
PassConfig, VllmConfig)
from vllm.distributed import (tensor_model_parallel_all_gather,
tensor_model_parallel_reduce_scatter)
from vllm.distributed.parallel_state import (init_distributed_environment,
initialize_model_parallel)
from vllm.platforms import current_platform
from vllm.utils import update_environment_variables

from ..models.registry import HF_EXAMPLE_MODELS
from ..utils import (compare_two_settings, create_new_process_for_each_test,
multi_gpu_test)
from .backend import TestBackend

prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]


class TestMMRSModel(torch.nn.Module):

def __init__(self, hidden_size=16):
super().__init__()
self.hidden_size = hidden_size
self.gate_proj = torch.nn.Parameter(torch.empty(
(self.hidden_size * 2, hidden_size)),
requires_grad=False)
# Initialize weights
torch.nn.init.normal_(self.gate_proj, std=0.02)

def forward(self, hidden_states):
"""
Forward pass implementing the mm + reduce scatter in the FX graph

"""
# Reshape input
view = hidden_states.reshape(-1, self.hidden_size)

# matrix multiplication
permute = self.gate_proj.permute(1, 0)
mm = torch.mm(view, permute)
reduce_scatter = tensor_model_parallel_reduce_scatter(mm, dim=0)
return reduce_scatter

def ops_in_model_before(self):
return [torch.ops.vllm.reduce_scatter.default]

def ops_in_model_after(self):
return [torch.ops.symm_mem.fused_matmul_reduce_scatter.default]


class TestAGMMModel(torch.nn.Module):

def __init__(self, hidden_size=16):
super().__init__()
self.hidden_size = hidden_size
self.weight = torch.nn.Parameter(torch.empty(
(hidden_size, hidden_size)),
requires_grad=False)
# Initialize weights
torch.nn.init.normal_(self.weight, std=0.02)

def forward(self, hidden_states):
"""
Forward pass implementing the mm + all gather in the FX graph
"""
# Reshape input
view = hidden_states.reshape(-1, self.hidden_size)
all_gather = tensor_model_parallel_all_gather(view, dim=0)
permute = self.weight.permute(1, 0)
mm = torch.mm(all_gather, permute)
return mm

def ops_in_model_before(self):
return [torch.ops.vllm.all_gather.default]

def ops_in_model_after(self):
return [torch.ops.symm_mem.fused_all_gather_matmul.default]


@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize("test_model", [TestMMRSModel, TestAGMMModel])
@pytest.mark.parametrize("batch_size", [8])
@pytest.mark.parametrize("seq_len", [16])
@pytest.mark.parametrize("hidden_size", [16])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"],
reason="Only test on CUDA")
def test_async_tp_pass_replace(test_model: str, batch_size: int, seq_len: int,
hidden_size: int, dtype: torch.dtype):
num_processes = 2

def run_torch_spawn(fn, nprocs):
# need to use torch.mp.spawn otherwise will have problems with
# torch.distributed and cuda
torch.multiprocessing.spawn(fn,
args=(num_processes, test_model,
batch_size, seq_len, hidden_size,
dtype),
nprocs=nprocs)

run_torch_spawn(async_tp_pass_on_test_model, num_processes)


def async_tp_pass_on_test_model(local_rank: int, world_size: int,
test_model_cls: torch.nn.Module,
batch_size: int, seq_len: int,
hidden_size: int, dtype: torch.dtype):
current_platform.seed_everything(0)

device = torch.device(f"cuda:{local_rank}")
torch.cuda.set_device(device)
torch.set_default_device(device)
torch.set_default_dtype(dtype)

update_environment_variables({
'RANK': str(local_rank),
'LOCAL_RANK': str(local_rank),
'WORLD_SIZE': str(world_size),
'MASTER_ADDR': 'localhost',
'MASTER_PORT': '12345',
})

# initialize distributed
init_distributed_environment()
initialize_model_parallel(tensor_model_parallel_size=world_size)

# configure vllm config for SequenceParallelismPass
vllm_config = VllmConfig()
vllm_config.compilation_config = CompilationConfig(pass_config=PassConfig(
enable_async_tp=True, ), )
vllm_config.device_config = DeviceConfig(device=torch.device("cuda"))

# this is a fake model name to construct the model config
# in the vllm_config, it's not really used.
model_name = "nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e"
vllm_config.model_config = ModelConfig(model=model_name,
task="auto",
tokenizer=model_name,
tokenizer_mode="auto",
trust_remote_code=True,
dtype=dtype,
seed=42)

async_tp_pass = AsyncTPPass(vllm_config)
backend = TestBackend(async_tp_pass)

model = test_model_cls(hidden_size)

hidden_states = torch.randn((batch_size * seq_len, hidden_size),
dtype=dtype,
requires_grad=False)

compiled_model = torch.compile(model, backend=backend)
compiled_model(hidden_states)

# In pre-nodes, all gather or reduce scatter should exist,
# fused_matmul_reduce_scatter or fused_all_gather_matmul should not
backend.check_before_ops(model.ops_in_model_before(),
ops_fully_replaced=False)

# In post-nodes, fused_matmul_reduce_scatter or \
# fused_all_gather_matmul should exist
backend.check_after_ops(model.ops_in_model_after())


@create_new_process_for_each_test()
@pytest.mark.parametrize("model_id", ["meta-llama/Llama-3.2-1B-Instruct"])
@pytest.mark.parametrize("tp_size", [2])
@pytest.mark.parametrize("async_tp_enabled", [True])
@pytest.mark.parametrize("distributed_backend", ["mp"])
@pytest.mark.parametrize("eager_mode", [False, True])
def test_async_tp_pass_correctness(
model_id: str,
tp_size: int,
async_tp_enabled: bool,
distributed_backend: str,
eager_mode: bool,
num_gpus_available: int,
):
model_info = HF_EXAMPLE_MODELS.find_hf_info(model_id)
model_info.check_transformers_version(on_fail="skip")

# trust_remote_code = model_info.trust_remote_code
# tokenizer_mode = model_info.tokenizer_mode
# hf_overrides = model_info.hf_overrides
model_info.check_available_online(on_fail="skip")

pp_size = 1
if num_gpus_available < tp_size:
pytest.skip(f"Need at least {tp_size} x {pp_size} GPUs")

common_args = [
# use half precision for speed and memory savings in CI environment
"--dtype",
"float16",
"--max-model-len",
"2048",
"--max-num-seqs",
"8",
]
if eager_mode:
common_args.append("--enforce-eager")

compilation_config = {
'level': 3,
'compile_sizes': [2, 4, 8],
'splitting_ops': [],
'pass_config': {
'enable_async_tp': async_tp_enabled
},
}

async_tp_env = tp_env = {
"VLLM_USE_V1": "1",
}

aysnc_tp_args = [
*common_args,
"--tensor-parallel-size",
str(tp_size),
"--distributed-executor-backend",
distributed_backend,
"--compilation_config",
json.dumps(compilation_config),
]

tp_args = [
*common_args,
"--tensor-parallel-size",
str(tp_size),
"--distributed-executor-backend",
"mp",
]

compare_two_settings(model_id,
aysnc_tp_args,
tp_args,
async_tp_env,
tp_env,
method="generate")
36 changes: 17 additions & 19 deletions tests/compile/test_fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ def __init__(self, hidden_size: int, eps: float, static: bool,
self.cutlass_fp8_enabled = cutlass_fp8_enabled
self.norm = [RMSNorm(hidden_size, eps) for _ in range(3)]
self.wscale = [torch.rand(1, dtype=torch.float32) for _ in range(2)]
self.key = QuantKey(dtype=FP8_DTYPE,
static=static,
per_tensor=static,
symmetric=True)
if static:
self.scale = [torch.rand(1, dtype=torch.float32) for _ in range(2)]
else:
Expand Down Expand Up @@ -59,6 +63,15 @@ def forward(self, x):
y3, resid = self.norm[2](x3, resid) # use resid here
return y3

def ops_in_model_before(self):
return [QUANT_OPS[self.key]]

def ops_in_model_after(self):
return [
FUSED_OPS[FusedRMSQuantKey(self.key, False)],
FUSED_OPS[FusedRMSQuantKey(self.key, True)]
]


@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("hidden_size", [64, 3392, 4096])
Expand Down Expand Up @@ -107,25 +120,10 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static,

torch.testing.assert_close(result, result2, atol=ATOL, rtol=RTOL)

# Check substitution worked
pre_nodes = backend.graph_pre_pass.nodes
post_nodes = backend.graph_post_pass.nodes

# static is per-tensor, dynamic is per-token
key = QuantKey(dtype=FP8_DTYPE,
static=static,
per_tensor=static,
symmetric=True)
rms_quant = FUSED_OPS[FusedRMSQuantKey(key, False)]
add_rms_quant = FUSED_OPS[FusedRMSQuantKey(key, True)]
fp8_quant = QUANT_OPS[key]

# In pre-nodes, fp8 quant should be there and fused kernels should not
assert find_auto_fn_maybe(pre_nodes, rms_quant) is None
assert find_auto_fn_maybe(pre_nodes, add_rms_quant) is None
find_auto_fn(pre_nodes, fp8_quant)
backend.check_before_ops(model.ops_in_model_before(), find_auto_fn,
find_auto_fn_maybe)

# In post-nodes, fused kernels should be there and fp8 quant should not
find_auto_fn(post_nodes, rms_quant)
find_auto_fn(post_nodes, add_rms_quant)
assert find_auto_fn_maybe(post_nodes, fp8_quant) is None
backend.check_before_ops(model.ops_in_model_after(), find_auto_fn,
find_auto_fn_maybe)
Loading