12 changes: 8 additions & 4 deletions tensorrt_llm/_torch/compilation/backend.py
@@ -11,6 +11,7 @@

import tensorrt_llm
from tensorrt_llm import logger
from tensorrt_llm.mapping import Mapping

from .multi_stream.auto_multi_stream import multi_stream_schedule
from .patterns.ar_residual_norm import register_ar_fusions
@@ -39,13 +40,16 @@ def __init__(
enable_piecewise_cuda_graph: bool = False,
capture_num_tokens: Optional[List[int]] = None,
max_num_streams: int = 1,
mapping=None,
) -> None:
super().__init__()
self.elapsed_time = 0
self.module_inference_event = []
self.module_inference_time = 0
self.call_count = 0
self.custom_passes = Backend.get_custom_pass(enable_userbuffers)
self.mapping = mapping
self.custom_passes = Backend.get_custom_pass(enable_userbuffers,
mapping)
self.rank = tensorrt_llm.mpi_rank()
self.enable_inductor = enable_inductor
self.capture_num_tokens = sorted(capture_num_tokens or [])
@@ -61,8 +65,7 @@ def __init__(
self.match_count = []

@classmethod
def get_custom_pass(cls, enable_userbuffers):
# TODO: add pp + tp support
def get_custom_pass(cls, enable_userbuffers, mapping: Mapping):
world_size = tensorrt_llm.mpi_world_size()
if not cls._custom_pass_instances:
# Really naive pass manager here
@@ -73,7 +76,8 @@ def get_custom_pass(cls, enable_userbuffers):
os.environ["DISABLE_LAMPORT_REDUCE_NORM_FUSION"] = "1"
ub_enabled = enable_userbuffers and tensorrt_llm.bindings.internal.userbuffers.ub_supported(
)
register_ar_fusions(cls._custom_pass_instances, ub_enabled)
register_ar_fusions(cls._custom_pass_instances, mapping,
ub_enabled)
else:
register_add_norm(cls._custom_pass_instances[0])
return cls._custom_pass_instances
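A minimal sketch (not part of this diff) of the new registration entry point, assuming an MPI launch and that Mapping accepts a pp_size argument for the pipeline dimension; the 2-way TP x 2-way PP layout is illustrative:

import tensorrt_llm
from tensorrt_llm.mapping import Mapping
from tensorrt_llm._torch.compilation.backend import Backend

# Build the parallel layout once and thread it into the custom passes, instead of
# each pattern rebuilding a TP-only Mapping from the MPI world size.
mapping = Mapping(world_size=4, tp_size=2, pp_size=2, rank=tensorrt_llm.mpi_rank())
custom_passes = Backend.get_custom_pass(enable_userbuffers=True, mapping=mapping)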
@@ -209,7 +209,9 @@ def flatten_args(args):
for inplace_arg in inplace_map[func].values():
# At this stage, all inplace op must be using kwargs for all params
assert inplace_arg in node.kwargs
latest_inplace_stat[node.kwargs[inplace_arg]] = vertex
args = flatten_args([node.kwargs[inplace_arg]])
for arg in args:
latest_inplace_stat[arg] = vertex

for edge in in_edges.values():
edge.out_edges.append(vertex)
73 changes: 19 additions & 54 deletions tensorrt_llm/_torch/compilation/patterns/ar_residual_norm.py
@@ -8,21 +8,14 @@
PatternMatcherPass, fwd_only,
register_replacement)

import tensorrt_llm

from ...distributed import AllReduceFusionOp, AllReduceStrategy

aten = torch.ops.aten
from tensorrt_llm.mapping import Mapping


def register_ar_residual_norm(custom_pass: PatternMatcherPass):
# TODO: add pp + tp support
mapping = Mapping(
world_size=tensorrt_llm.mpi_world_size(),
tp_size=tensorrt_llm.mpi_world_size(),
rank=tensorrt_llm.mpi_rank(),
)
def register_ar_residual_norm(custom_pass: PatternMatcherPass,
mapping: Mapping):
residual_key = KeywordArg("residual")
trtllm_allreduce_default = CallFunction(
torch.ops.trtllm.allreduce.default, KeywordArg("input"), None, None,
@@ -117,14 +110,8 @@ def check_non_ub_strategy(match, strategy_node) -> bool:
return True


def register_ar_residual_norm_out_fp8_quant(custom_pass: PatternMatcherPass):
# TODO: add pp + tp support
mapping = Mapping(
world_size=tensorrt_llm.mpi_world_size(),
tp_size=tensorrt_llm.mpi_world_size(),
rank=tensorrt_llm.mpi_rank(),
)

def register_ar_residual_norm_out_fp8_quant(custom_pass: PatternMatcherPass,
mapping: Mapping):
input_node = KeywordArg("input")
strategy_node = KeywordArg("strategy")
allreduce_default = CallFunction(torch.ops.trtllm.allreduce.default,
@@ -200,14 +187,8 @@ def extra_check(match: Match) -> bool:
)


def register_ar_residual_norm_fp8_quant(custom_pass: PatternMatcherPass):
# TODO: add pp + tp support
mapping = Mapping(
world_size=tensorrt_llm.mpi_world_size(),
tp_size=tensorrt_llm.mpi_world_size(),
rank=tensorrt_llm.mpi_rank(),
)

def register_ar_residual_norm_fp8_quant(custom_pass: PatternMatcherPass,
mapping: Mapping):
input_node = KeywordArg("input")
strategy_node = KeywordArg("strategy")
allreduce_default = CallFunction(torch.ops.trtllm.allreduce.default,
@@ -282,14 +263,8 @@ def extra_check(match: Match) -> bool:
)


def register_ar_residual_norm_out_fp4_quant(custom_pass: PatternMatcherPass):
# TODO: add pp + tp support
mapping = Mapping(
world_size=tensorrt_llm.mpi_world_size(),
tp_size=tensorrt_llm.mpi_world_size(),
rank=tensorrt_llm.mpi_rank(),
)

def register_ar_residual_norm_out_fp4_quant(custom_pass: PatternMatcherPass,
mapping: Mapping):
input_node = KeywordArg("input")
strategy_node = KeywordArg("strategy")
allreduce_default = CallFunction(torch.ops.trtllm.allreduce.default,
@@ -360,14 +335,8 @@ def extra_check(match: Match) -> bool:
)


def register_ar_residual_norm_fp4_quant(custom_pass: PatternMatcherPass):
# TODO: add pp + tp support
mapping = Mapping(
world_size=tensorrt_llm.mpi_world_size(),
tp_size=tensorrt_llm.mpi_world_size(),
rank=tensorrt_llm.mpi_rank(),
)

def register_ar_residual_norm_fp4_quant(custom_pass: PatternMatcherPass,
mapping: Mapping):
input_node = KeywordArg("input")
strategy_node = KeywordArg("strategy")
allreduce_default = CallFunction(torch.ops.trtllm.allreduce.default,
@@ -437,12 +406,8 @@ def extra_check(match: Match) -> bool:
)


def register_ub_patterns(custom_passes: List[PatternMatcherPass]):
mapping = Mapping(
world_size=tensorrt_llm.mpi_world_size(),
tp_size=tensorrt_llm.mpi_world_size(),
rank=tensorrt_llm.mpi_rank(),
)
def register_ub_patterns(custom_passes: List[PatternMatcherPass],
mapping: Mapping):

def register_convert_supported_ar_to_ub(custom_pass: PatternMatcherPass):
strategy = int(AllReduceStrategy.AUTO)
@@ -717,16 +682,16 @@ def target_finalize_pattern(


def register_ar_fusions(custom_passes: List[PatternMatcherPass],
enable_ub: bool):
register_ar_residual_norm(custom_passes[-1])
mapping: Mapping, enable_ub: bool):
register_ar_residual_norm(custom_passes[-1], mapping)

custom_passes.append(PatternMatcherPass())
register_ar_residual_norm_fp8_quant(custom_passes[-1])
register_ar_residual_norm_fp4_quant(custom_passes[-1])
register_ar_residual_norm_fp8_quant(custom_passes[-1], mapping)
register_ar_residual_norm_fp4_quant(custom_passes[-1], mapping)
# AR-Residual-Norm-Out-Quant-X is not supported by Userbuffers kernel.
if not enable_ub:
register_ar_residual_norm_out_fp8_quant(custom_passes[-1])
register_ar_residual_norm_out_fp4_quant(custom_passes[-1])
register_ar_residual_norm_out_fp8_quant(custom_passes[-1], mapping)
register_ar_residual_norm_out_fp4_quant(custom_passes[-1], mapping)

if enable_ub:
register_ub_patterns(custom_passes)
register_ub_patterns(custom_passes, mapping)
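With the explicit Mapping parameter, the fusion registration can be driven from a single call site. A sketch (not part of this diff); the 4-way TP x 2-way PP layout and rank value are illustrative:

from torch._inductor.pattern_matcher import PatternMatcherPass

from tensorrt_llm.mapping import Mapping
from tensorrt_llm._torch.compilation.patterns.ar_residual_norm import register_ar_fusions

mapping = Mapping(world_size=8, tp_size=4, pp_size=2, rank=0)
custom_passes = [PatternMatcherPass()]
# Registers the AR + residual + norm patterns (and their quant variants); the
# userbuffer patterns are skipped because enable_ub is False.
register_ar_fusions(custom_passes, mapping, enable_ub=False)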
4 changes: 4 additions & 0 deletions tensorrt_llm/_torch/compilation/piecewise_optimizer.py
@@ -79,6 +79,10 @@ def call_module(self, target, args, kwargs):
dim_idx)
found_dynamic_shape = True
break
if not found_dynamic_shape:
raise RuntimeError(
"Cannot identify dynamic shape, please disable enable_piecewise_cuda_graph in TorchCompileConfig"
)

if self.max_num_streams > 1 and not self.enable_inductor:
num_events = multi_stream_schedule(submod, self.max_num_streams)
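When this new error fires, the suggested mitigation is to turn off piecewise CUDA graphs in the compile configuration. A sketch; the import path and the enable_fullgraph field are assumptions based on the test configuration:

from tensorrt_llm.llmapi import TorchCompileConfig

# Disable piecewise CUDA graphs, as the error message suggests.
torch_compile_config = TorchCompileConfig(
    enable_fullgraph=True,
    enable_piecewise_cuda_graph=False,
)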
6 changes: 6 additions & 0 deletions tensorrt_llm/_torch/compilation/utils.py
@@ -76,6 +76,12 @@ def inplace_info():
},
torch.ops.trtllm.logits_bitmask.default: {
1: "logits"
},
torch.ops.trtllm.pp_recv_tensors.default: {
1: "tensors"
},
torch.ops.trtllm.pp_send_tensors.default: {
1: "tensors"
}
}
return inplace_map
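The new entries mark the "tensors" kwarg of the PP send/recv ops as mutated in place, so the compilation passes track them like the other in-place ops. A lookup sketch, assuming the trtllm custom ops have been registered (e.g. by importing the communicator module):

import torch

from tensorrt_llm._torch.distributed import communicator  # registers trtllm::pp_*_tensors  # noqa: F401
from tensorrt_llm._torch.compilation.utils import inplace_info

inplace_map = inplace_info()
# Values are the kwarg names of the in-place (mutated) arguments of each op overload.
print(inplace_map[torch.ops.trtllm.pp_recv_tensors.default])  # {1: "tensors"}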
18 changes: 17 additions & 1 deletion tensorrt_llm/_torch/distributed/communicator.py
@@ -3,7 +3,7 @@
import pickle # nosec B403
from abc import ABC, abstractmethod
from functools import wraps
from typing import Optional
from typing import List, Optional

import numpy as np
import torch
@@ -848,3 +848,19 @@ def pp_recv(tensor):
def pp_send(tensor):
"""Send tensors to next pp rank."""
_pp_comm.send(tensor)


@torch.library.custom_op("trtllm::pp_recv_tensors", mutates_args=("tensors", ))
def pp_recv_tensors(tensors: List[torch.Tensor]) -> None:
"""
Receive tensors from previous pp rank.
"""
for tensor in tensors:
pp_recv(tensor)


@torch.library.custom_op("trtllm::pp_send_tensors", mutates_args=("tensors", ))
def pp_send_tensors(tensors: List[torch.Tensor]) -> None:
"""Send tensors to next pp rank."""
for tensor in tensors:
pp_send(tensor)
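A minimal usage sketch of the batched PP transfer ops; the tensor shapes, dtypes, and the two-rank split are illustrative assumptions, and the pipeline communicator must already be initialized:

import torch

from tensorrt_llm._torch.distributed.communicator import (pp_recv_tensors,
                                                           pp_send_tensors)

hidden_states = torch.empty(8, 4096, dtype=torch.bfloat16, device="cuda")
residual = torch.empty_like(hidden_states)

# Upstream PP rank: send both tensors after running its layers.
pp_send_tensors([hidden_states, residual])

# Downstream PP rank: receive into preallocated buffers before running its layers.
pp_recv_tensors([hidden_states, residual])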
12 changes: 6 additions & 6 deletions tensorrt_llm/_torch/models/modeling_utils.py
@@ -18,7 +18,7 @@
from ...logger import logger
from ...models.modeling_utils import QuantConfig
from ..attention_backend import AttentionMetadata
from ..distributed.communicator import pp_recv, pp_send
from ..distributed.communicator import pp_recv_tensors, pp_send_tensors
from ..model_config import ModelConfig, TConfig
from ..modules.attention import Attention
from ..modules.embedding import Embedding, LMHead
@@ -172,11 +172,12 @@ def forward_after_recv_fn(
residual=...,
**kwargs,
):
pp_recv(hidden_states)
if residual is not ...:
if residual is None:
residual = torch.empty_like(hidden_states)
pp_recv(residual)
pp_recv_tensors([hidden_states, residual])
else:
pp_recv_tensors([hidden_states])
return forward_fn(
position_ids,
hidden_states,
@@ -209,11 +210,10 @@ def forward_before_send_fn(
)
if residual is not ...:
hidden_states, residual = output
pp_send(hidden_states)
pp_send(residual)
pp_send_tensors([hidden_states, residual])
else:
hidden_states = output
pp_send(hidden_states)
pp_send_tensors([hidden_states])
return output

forward_before_send_fn.__wrapped_by_forward_before_send__ = True
26 changes: 19 additions & 7 deletions tensorrt_llm/_torch/modules/embedding.py
@@ -180,6 +180,24 @@ def pre_comm_embedding_ops(
return output


def embedding_skip_forward_impl(input: torch.Tensor, embedding_dim: int,
dtype: torch.dtype) -> torch.Tensor:
output_shape = input.shape[:] + (embedding_dim, )
output = input.new_empty(output_shape, dtype=dtype)
return output


@torch.library.custom_op("trtllm::embedding_skip_forward", mutates_args=())
def embedding_skip_forward(input: torch.Tensor, embedding_dim: int,
dtype: torch.dtype) -> torch.Tensor:
return embedding_skip_forward_impl(input, embedding_dim, dtype)


@embedding_skip_forward.register_fake
def _(input, embedding_dim, dtype):
return embedding_skip_forward_impl(input, embedding_dim, dtype)


class Embedding(LMHead):
"""Embedding layer.

@@ -258,10 +276,4 @@ def forward(self, input):
return output

def skip_forward(self, input):
output_shape = input.shape[:] + (self.embedding_dim, )
output = torch.empty(
output_shape,
dtype=self.dtype,
device=input.device,
)
return output
return embedding_skip_forward(input, self.embedding_dim, self.dtype)
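Routing skip_forward through a registered custom op (with a matching fake implementation) gives torch.compile a traceable, shape-only path on ranks that skip the embedding. A small sketch with illustrative shapes:

import torch

from tensorrt_llm._torch.modules.embedding import embedding_skip_forward

input_ids = torch.zeros(16, dtype=torch.int64, device="cuda")
out = embedding_skip_forward(input_ids, 4096, torch.bfloat16)
assert out.shape == (16, 4096) and out.dtype == torch.bfloat16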
5 changes: 3 additions & 2 deletions tensorrt_llm/_torch/pyexecutor/model_engine.py
@@ -283,7 +283,8 @@ def __init__(
enable_piecewise_cuda_graph=self.
_torch_compile_piecewise_cuda_graph,
capture_num_tokens=self._piecewise_cuda_graph_num_tokens,
max_num_streams=torch_compile_max_num_streams)
max_num_streams=torch_compile_max_num_streams,
mapping=self.mapping)
if isinstance(self.model, DecoderModelForCausalLM):
self.model.model = torch.compile(
self.model.model,
@@ -2824,7 +2825,7 @@ def _forward_step_mm_encoder_only(
return {'mm_embeddings': mm_embeddings, 'logits': None}

def _init_userbuffers(self, hidden_size):
if self.mapping.tp_size <= 1:
if self.mapping.tp_size <= 1 or self.mapping.pp_size > 1:
return False

# Disable UB for unsupported platforms
16 changes: 0 additions & 16 deletions tests/integration/defs/accuracy/test_llm_api_pytorch.py
@@ -122,11 +122,6 @@ def test_bfloat16(self, attn_backend, torch_compile):
ids=["tp4", "tp2pp2", "pp4"])
def test_bfloat16_4gpus(self, tp_size, pp_size, attn_backend,
torch_compile):
if torch_compile and pp_size > 1:
pytest.skip(
"Pipeline parallel with torch.compile is not supported yet.\n"
"Issue: Unfusing flashinfer_fused_add_rmsnorm causes outputs to be "
"discarded at graph breaks.")
torch_compile_config = TorchCompileConfig(
enable_fullgraph=True,
enable_piecewise_cuda_graph=True,
@@ -1331,9 +1326,6 @@ def test_bfloat16(self, mtp_nextn, attention_dp, cuda_graph,
def test_bfloat16_4gpus(self, tp_size, pp_size, ep_size, mtp_nextn,
attention_dp, cuda_graph, overlap_scheduler,
torch_compile):
if torch_compile and pp_size > 1:
pytest.skip("PP with torch.compile is not supported yet.")

if pp_size > 1 and mtp_nextn > 0:
num_hidden_layers = 30
pp_partition = [num_hidden_layers // pp_size + 1] * pp_size
@@ -1432,8 +1424,6 @@ def test_cute_dsl_fp8_block_scales(
overlap_scheduler,
torch_compile,
):
if torch_compile and attention_dp:
pytest.skip("https://nvbugs/5252559")
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.9)
torch_compile_config = (TorchCompileConfig(
enable_fullgraph=True,
@@ -1537,8 +1527,6 @@ def test_fp8_block_scales_cuda_graph_padding_4gpus(self, mtp_nextn,
def test_fp8_block_scales_4gpus(self, tp_size, pp_size, ep_size, mtp_nextn,
fp8kv, attention_dp, cuda_graph,
overlap_scheduler, torch_compile):
if torch_compile and pp_size > 1:
pytest.skip("PP with torch.compile is not supported yet.")
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.75)
torch_compile_config = TorchCompileConfig(
enable_fullgraph=True,
@@ -1599,8 +1587,6 @@ def test_cute_dsl_fp8_block_scales_4gpus(
overlap_scheduler,
torch_compile,
):
if torch_compile and pp_size > 1:
pytest.skip("PP with torch.compile is not supported yet.")
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.9)
torch_compile_config = (TorchCompileConfig(
enable_fullgraph=True,
@@ -1826,8 +1812,6 @@ def test_nvfp4_batch_waiting(self, torch_compile, fp8kv, cuda_graph,
def test_nvfp4_4gpus(self, fp8kv, attention_dp, cuda_graph,
overlap_scheduler, tp_size, pp_size, ep_size,
torch_compile, mtp_nextn, moe_backend):
if torch_compile and pp_size > 1:
pytest.skip("PP with torch.compile is not supported yet.")
if moe_backend == "TRTLLM" and (get_sm_version() == 120
or get_sm_version() == 121):
pytest.skip(