Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 69 additions & 0 deletions tests/compile/fusions_e2e/test_tp2_async_tp.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
AttentionBackendCase,
Matches,
custom_ops_combos,
is_blackwell,
)
from .models import (
FLASHINFER_ATTN,
Expand All @@ -22,6 +23,9 @@
llama4_scout_fp8,
qwen3_a3b,
)
from .models import (
llama3_8b_fp4 as llama3_8b_nvfp4,
)

pytestmark = pytest.mark.skipif(not current_platform.is_cuda(), reason="Only test CUDA")

Expand Down Expand Up @@ -90,6 +94,71 @@ def test_tp2_async_tp_fp8_fusions(
)


@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize(
"model_name, matches_fn, model_kwargs, hf_overrides",
[llama3_8b_nvfp4],
)
@pytest.mark.parametrize("attn_backend", [FLASHINFER_ATTN])
@pytest.mark.parametrize("n_layers", [4])
@pytest.mark.parametrize("custom_ops", custom_ops_combos("rms_norm"))
@pytest.mark.parametrize("inductor_graph_partition", INDUCTOR_GRAPH_PARTITION)
@pytest.mark.skipif(not is_blackwell(), reason="Blackwell required for fp4")
@pytest.mark.skipif(not current_platform.is_cuda(), reason="Only test CUDA")
def test_tp2_async_tp_nvfp4_fusions(
model_name: str,
matches_fn: Callable[[int], Matches],
model_kwargs: dict,
hf_overrides: Callable[[int], dict],
attn_backend: AttentionBackendCase,
n_layers: int,
custom_ops: str,
inductor_graph_partition: bool,
run_e2e_fusion_test,
):
# NVFP4 currently wires the all-gather + GEMM path only. The generic

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's set this on llama-fp4 model directly?

# reduce-scatter fusion is intentionally not reused because NVFP4 group
# scales need layout-aware sharding.
matches = matches_fn(n_layers)._replace(async_tp=n_layers * 2)

# Reduce size of model and skip weight loading time
model_kwargs["hf_overrides"] = hf_overrides(n_layers)
model_kwargs["load_format"] = "dummy"
model_kwargs["max_model_len"] = 1024
model_kwargs["kernel_config"] = {"enable_flashinfer_autotune": False}

compilation_config = dict(
use_inductor_graph_partition=inductor_graph_partition,
custom_ops=custom_ops.split(","),
pass_config=PassConfig(
fuse_act_quant=True,
fuse_attn_quant=True,
enable_sp=True,
fuse_gemm_comms=True,
fuse_allreduce_rms=False,
# Override threshold for testing (models have small hidden_size)
sp_min_token_num=512,
),
)

matches_check = [
"act_quant_fusion",
"attn_quant_fusion",
"sequence_parallel",
"async_tp",
]

run_e2e_fusion_test(
model_name,
matches,
model_kwargs,
attn_backend,
compilation_config,
matches_check,
tp_size=2,
)


@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize(
"model_name, matches_fn, model_kwargs, hf_overrides",
Expand Down
242 changes: 242 additions & 0 deletions vllm/compilation/passes/fusion/collective_fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,36 @@ def _flashinfer_scaled_mm_out(
)


def _flashinfer_fp4_mm_out(
A: torch.Tensor,
B: torch.Tensor,
*,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
out: torch.Tensor,
alpha: torch.Tensor,
out_dtype: torch.dtype | None = None,
use_8x4_sf_layout: bool = False,
backend: str = "cutlass",
) -> None:
from vllm.utils.flashinfer import flashinfer_scaled_fp4_mm_out

assert A.ndim == 2 and B.ndim == 2 and out.ndim == 2, (
"FlashInfer FP4 symm_mem adapter expects 2D inputs and output"
)
flashinfer_scaled_fp4_mm_out(
A,
B,
scale_a,
scale_b,
alpha,
out=out,
out_dtype=out_dtype or out.dtype,
use_8x4_sf_layout=use_8x4_sf_layout,
backend=backend,
)


def fused_flashinfer_scaled_matmul_reduce_scatter_fake(
A: torch.Tensor,
B: torch.Tensor,
Expand Down Expand Up @@ -197,6 +227,90 @@ def fused_all_gather_flashinfer_scaled_matmul(
return outputs[0]


def fused_all_gather_flashinfer_fp4_matmul_fake(
A_shard: torch.Tensor,
B: torch.Tensor,
A_scale_shard: torch.Tensor,
B_scale: torch.Tensor,
alpha: torch.Tensor,
gather_dim: int,
group_name: str,
out_dtype: torch.dtype | None = None,
view_a_scale_as_fp8: bool = False,
use_8x4_sf_layout: bool = False,
backend: str = "cutlass",
) -> torch.Tensor:
world_size = c10d._resolve_process_group(group_name).size()
output_shape = list(A_shard.shape)
output_shape[gather_dim] *= world_size
output_shape[-1] = B.shape[1]
return torch.empty(
output_shape,
dtype=out_dtype or torch.bfloat16,
device=A_shard.device,
)


def fused_all_gather_flashinfer_fp4_matmul(
A_shard: torch.Tensor,
B: torch.Tensor,
A_scale_shard: torch.Tensor,
B_scale: torch.Tensor,
alpha: torch.Tensor,
gather_dim: int,
group_name: str,
out_dtype: torch.dtype | None = None,
view_a_scale_as_fp8: bool = False,
use_8x4_sf_layout: bool = False,
backend: str = "cutlass",
) -> torch.Tensor:
assert gather_dim == 0, (
"FlashInfer FP4 symm_mem adapter currently only supports gather_dim=0"
)
assert A_shard.ndim == 2 and A_scale_shard.ndim == 2 and B.ndim == 2, (
"FlashInfer FP4 symm_mem adapter expects 2D inputs"
)
if view_a_scale_as_fp8:
A_scale_shard = A_scale_shard.view(torch.float8_e4m3fn)

group = c10d._resolve_process_group(group_name)
world_size = group.size()
output = A_shard.new_empty(
A_shard.shape[0] * world_size,
B.shape[1],
dtype=out_dtype or torch.bfloat16,
)
output_shards = output.chunk(world_size)

A = A_shard.new_empty(A_shard.shape[0] * world_size, A_shard.shape[1])
A_scale = A_scale_shard.new_empty(
Comment on lines +285 to +286

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The intermediate buffers A and A_scale are allocated using new_empty on every call to this custom op. For AsyncTP to be effective, these buffers should ideally be allocated in symmetric memory to avoid unnecessary copies during the all-gather operation. Furthermore, constant allocation of large buffers in the hot path can lead to significant performance overhead. Consider using torch.ops.symm_mem.empty_symm_mem or a similar mechanism to ensure these buffers are symmetric and potentially cached.

A_scale_shard.shape[0] * world_size,
A_scale_shard.shape[1],
)

def fp4_shard_consumer(shards: list[torch.Tensor], rank: int) -> None:
_flashinfer_fp4_mm_out(
shards[0],
B,
scale_a=shards[1],
scale_b=B_scale,
alpha=alpha,
out=output_shards[rank],
out_dtype=out_dtype,
use_8x4_sf_layout=use_8x4_sf_layout,
backend=backend,
)

torch.distributed._symmetric_memory._pipelined_multi_all_gather_and_consume(
[A_shard, A_scale_shard],
fp4_shard_consumer,
[A, A_scale],
group_name,
False,
)
return output


direct_register_custom_op(
op_name="fused_flashinfer_scaled_matmul_reduce_scatter",
op_func=fused_flashinfer_scaled_matmul_reduce_scatter,
Expand All @@ -209,6 +323,12 @@ def fused_all_gather_flashinfer_scaled_matmul(
fake_impl=fused_all_gather_flashinfer_scaled_matmul_fake,
)

direct_register_custom_op(
op_name="fused_all_gather_flashinfer_fp4_matmul",
op_func=fused_all_gather_flashinfer_fp4_matmul,
fake_impl=fused_all_gather_flashinfer_fp4_matmul_fake,
)


class BasePattern:
def __init__(self, dtype: torch.dtype, device: str | None) -> None:
Expand Down Expand Up @@ -682,6 +802,101 @@ def _replacement(
return _replacement


class FlashInferAllGatherFP4Pattern(
BasePattern, VllmPatternReplacement[..., torch.Tensor]
):
def __init__(
self,
dtype: torch.dtype,
device: str | None,
backend: str,
use_8x4_sf_layout: bool,
a_scale_view: str,
) -> None:
super().__init__(dtype, device)
self.backend = backend
self.use_8x4_sf_layout = use_8x4_sf_layout
self.a_scale_view = a_scale_view

def get_inputs(self) -> list[torch.Tensor]:
a_shard_2d = torch.empty([8, 8], device=self.device, dtype=torch.uint8)
b_2d = torch.empty([8, 16], device=self.device, dtype=torch.uint8)
a_scale_shard = torch.empty([128, 4], device=self.device, dtype=torch.int32)
b_scale = torch.empty([4, 128], device=self.device, dtype=torch.uint8)
alpha = torch.empty([], device=self.device, dtype=torch.float32)
return [
a_shard_2d,
b_2d,
a_scale_shard,
b_scale,
alpha,
]

@property
def pattern(self) -> Callable[..., torch.Tensor]:
def _pattern(
a_shard_2d: torch.Tensor,
b_2d: torch.Tensor,
a_scale_shard: torch.Tensor,
b_scale: torch.Tensor,
alpha: torch.Tensor,
) -> torch.Tensor:
all_gather_a = torch.ops.vllm.all_gather.default(
a_shard_2d,
dim=0,
world_size=self.tp_size,
group_name=self.tp.unique_name,
)
all_gather_a_scale = torch.ops.vllm.all_gather.default(
a_scale_shard,
dim=0,
world_size=self.tp_size,
group_name=self.tp.unique_name,
)
a_scale = all_gather_a_scale
if self.a_scale_view in ("float8", "float8_uint8"):
a_scale = torch.ops.aten.view.dtype(a_scale, torch.float8_e4m3fn)
if self.a_scale_view in ("uint8", "float8_uint8"):
a_scale = torch.ops.aten.view.dtype(a_scale, torch.uint8)
Comment on lines +857 to +860

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The double view logic for float8_uint8 is redundant. If the goal is to obtain a uint8 tensor, viewing directly as uint8 is sufficient regardless of whether it was previously viewed as float8. This simplifies the pattern and avoids unnecessary operations in the graph.

Suggested change
if self.a_scale_view in ("float8", "float8_uint8"):
a_scale = torch.ops.aten.view.dtype(a_scale, torch.float8_e4m3fn)
if self.a_scale_view in ("uint8", "float8_uint8"):
a_scale = torch.ops.aten.view.dtype(a_scale, torch.uint8)
if self.a_scale_view in ("float8", "float8_uint8"):
a_scale = torch.ops.aten.view.dtype(a_scale, torch.float8_e4m3fn)
elif self.a_scale_view == "uint8":
a_scale = torch.ops.aten.view.dtype(a_scale, torch.uint8)

return torch.ops.vllm.flashinfer_mm_fp4.default(
all_gather_a,
b_2d,
a_scale,
b_scale,
alpha,
self.dtype,
self.use_8x4_sf_layout,
self.backend,
)

return _pattern

@property
def replacement(self) -> Callable[..., torch.Tensor]:
def _replacement(
a_shard_2d: torch.Tensor,
b_2d: torch.Tensor,
a_scale_shard: torch.Tensor,
b_scale: torch.Tensor,
alpha: torch.Tensor,
) -> torch.Tensor:
return torch.ops.vllm.fused_all_gather_flashinfer_fp4_matmul.default(
a_shard_2d,
b_2d,
a_scale_shard,
b_scale,
alpha,
0,
self.tp.device_group.group_name,
self.dtype,
self.a_scale_view in ("float8", "float8_uint8"),
self.use_8x4_sf_layout,
self.backend,
)

return _replacement


class AsyncTPPass(VllmFusionPatternMatcherPass):
@enable_fake_mode
def __init__(self, config: VllmConfig) -> None:
Expand Down Expand Up @@ -718,6 +933,33 @@ def __init__(self, config: VllmConfig) -> None:
self.register(
FlashInferBMMFP8ReduceScatterPattern(self.model_dtype, self.device)
)
if hasattr(torch.ops.vllm, "flashinfer_mm_fp4"):
for backend in ("cutlass", "cudnn"):
for a_scale_view in ("float8_uint8", "uint8"):
self.register(
FlashInferAllGatherFP4Pattern(
self.model_dtype,
self.device,
backend,
use_8x4_sf_layout=False,
a_scale_view=a_scale_view,
)
)
for use_8x4_sf_layout in (False, True):
for a_scale_view in ("float8",):
self.register(
FlashInferAllGatherFP4Pattern(
self.model_dtype,
self.device,
"trtllm",
use_8x4_sf_layout=use_8x4_sf_layout,
a_scale_view=a_scale_view,
)
)
# NVFP4 activation scales are block/group scales, not FP8

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wait, thinking about this again, isn't reduce scatter trivial? Inputs are already column-parallel across ranks, so each rank has the appropriate scales and inputs only. Output is full size but it's activations only (and partial numerically), so reduction is needed but only on the output, no scale comms need to be involved.

Am I missing something?

# row-wise scales. Register only the all-gather path until the
# reduce-scatter side has a dedicated NVFP4 scale-sharding
# implementation.

self.dump_patterns(config, self.pm_pass)

Expand Down
Loading
Loading