diff --git a/tensorrt_llm/_torch/compilation/backend.py b/tensorrt_llm/_torch/compilation/backend.py
index dc77dec6696..42bf851136d 100644
--- a/tensorrt_llm/_torch/compilation/backend.py
+++ b/tensorrt_llm/_torch/compilation/backend.py
@@ -11,6 +11,7 @@
 
 import tensorrt_llm
 from tensorrt_llm import logger
+from tensorrt_llm.mapping import Mapping
 
 from .multi_stream.auto_multi_stream import multi_stream_schedule
 from .patterns.ar_residual_norm import register_ar_fusions
@@ -39,13 +40,16 @@ def __init__(
         enable_piecewise_cuda_graph: bool = False,
         capture_num_tokens: Optional[List[int]] = None,
         max_num_streams: int = 1,
+        mapping=None,
     ) -> None:
         super().__init__()
         self.elapsed_time = 0
         self.module_inference_event = []
         self.module_inference_time = 0
         self.call_count = 0
-        self.custom_passes = Backend.get_custom_pass(enable_userbuffers)
+        self.mapping = mapping
+        self.custom_passes = Backend.get_custom_pass(enable_userbuffers,
+                                                     mapping)
         self.rank = tensorrt_llm.mpi_rank()
         self.enable_inductor = enable_inductor
         self.capture_num_tokens = sorted(capture_num_tokens or [])
@@ -61,8 +65,7 @@ def __init__(
         self.match_count = []
 
     @classmethod
-    def get_custom_pass(cls, enable_userbuffers):
-        # TODO: add pp + tp support
+    def get_custom_pass(cls, enable_userbuffers, mapping: Mapping):
         world_size = tensorrt_llm.mpi_world_size()
         if not cls._custom_pass_instances:
             # Really naive pass manager here
@@ -73,7 +76,8 @@ def get_custom_pass(cls, enable_userbuffers):
                     os.environ["DISABLE_LAMPORT_REDUCE_NORM_FUSION"] = "1"
                 ub_enabled = enable_userbuffers and tensorrt_llm.bindings.internal.userbuffers.ub_supported(
                 )
-                register_ar_fusions(cls._custom_pass_instances, ub_enabled)
+                register_ar_fusions(cls._custom_pass_instances, mapping,
+                                    ub_enabled)
             else:
                 register_add_norm(cls._custom_pass_instances[0])
         return cls._custom_pass_instances
diff --git a/tensorrt_llm/_torch/compilation/multi_stream/auto_multi_stream.py b/tensorrt_llm/_torch/compilation/multi_stream/auto_multi_stream.py
index c2d3cf012a0..bc0cd984fc5 100644
--- a/tensorrt_llm/_torch/compilation/multi_stream/auto_multi_stream.py
+++ b/tensorrt_llm/_torch/compilation/multi_stream/auto_multi_stream.py
@@ -209,7 +209,9 @@ def flatten_args(args):
             for inplace_arg in inplace_map[func].values():
                 # At this stage, all inplace op must be using kwargs for all params
                 assert inplace_arg in node.kwargs
-                latest_inplace_stat[node.kwargs[inplace_arg]] = vertex
+                args = flatten_args([node.kwargs[inplace_arg]])
+                for arg in args:
+                    latest_inplace_stat[arg] = vertex
 
         for edge in in_edges.values():
             edge.out_edges.append(vertex)
diff --git a/tensorrt_llm/_torch/compilation/patterns/ar_residual_norm.py b/tensorrt_llm/_torch/compilation/patterns/ar_residual_norm.py
index afbaa0949df..e69f903ac47 100644
--- a/tensorrt_llm/_torch/compilation/patterns/ar_residual_norm.py
+++ b/tensorrt_llm/_torch/compilation/patterns/ar_residual_norm.py
@@ -8,21 +8,14 @@
                                               PatternMatcherPass, fwd_only,
                                               register_replacement)
 
-import tensorrt_llm
-
 from ...distributed import AllReduceFusionOp, AllReduceStrategy
 
 aten = torch.ops.aten
 
 from tensorrt_llm.mapping import Mapping
 
 
-def register_ar_residual_norm(custom_pass: PatternMatcherPass):
-    # TODO: add pp + tp support
-    mapping = Mapping(
-        world_size=tensorrt_llm.mpi_world_size(),
-        tp_size=tensorrt_llm.mpi_world_size(),
-        rank=tensorrt_llm.mpi_rank(),
-    )
+def register_ar_residual_norm(custom_pass: PatternMatcherPass,
+                              mapping: Mapping):
     residual_key = KeywordArg("residual")
     trtllm_allreduce_default = CallFunction(
         torch.ops.trtllm.allreduce.default, KeywordArg("input"), None, None,
@@ -117,14 +110,8 @@ def check_non_ub_strategy(match, strategy_node) -> bool:
     return True
 
 
-def register_ar_residual_norm_out_fp8_quant(custom_pass: PatternMatcherPass):
-    # TODO: add pp + tp support
-    mapping = Mapping(
-        world_size=tensorrt_llm.mpi_world_size(),
-        tp_size=tensorrt_llm.mpi_world_size(),
-        rank=tensorrt_llm.mpi_rank(),
-    )
-
+def register_ar_residual_norm_out_fp8_quant(custom_pass: PatternMatcherPass,
+                                            mapping: Mapping):
     input_node = KeywordArg("input")
     strategy_node = KeywordArg("strategy")
     allreduce_default = CallFunction(torch.ops.trtllm.allreduce.default,
@@ -200,14 +187,8 @@ def extra_check(match: Match) -> bool:
     )
 
 
-def register_ar_residual_norm_fp8_quant(custom_pass: PatternMatcherPass):
-    # TODO: add pp + tp support
-    mapping = Mapping(
-        world_size=tensorrt_llm.mpi_world_size(),
-        tp_size=tensorrt_llm.mpi_world_size(),
-        rank=tensorrt_llm.mpi_rank(),
-    )
-
+def register_ar_residual_norm_fp8_quant(custom_pass: PatternMatcherPass,
+                                        mapping: Mapping):
     input_node = KeywordArg("input")
     strategy_node = KeywordArg("strategy")
     allreduce_default = CallFunction(torch.ops.trtllm.allreduce.default,
@@ -282,14 +263,8 @@ def extra_check(match: Match) -> bool:
     )
 
 
-def register_ar_residual_norm_out_fp4_quant(custom_pass: PatternMatcherPass):
-    # TODO: add pp + tp support
-    mapping = Mapping(
-        world_size=tensorrt_llm.mpi_world_size(),
-        tp_size=tensorrt_llm.mpi_world_size(),
-        rank=tensorrt_llm.mpi_rank(),
-    )
-
+def register_ar_residual_norm_out_fp4_quant(custom_pass: PatternMatcherPass,
+                                            mapping: Mapping):
     input_node = KeywordArg("input")
     strategy_node = KeywordArg("strategy")
     allreduce_default = CallFunction(torch.ops.trtllm.allreduce.default,
@@ -360,14 +335,8 @@ def extra_check(match: Match) -> bool:
     )
 
 
-def register_ar_residual_norm_fp4_quant(custom_pass: PatternMatcherPass):
-    # TODO: add pp + tp support
-    mapping = Mapping(
-        world_size=tensorrt_llm.mpi_world_size(),
-        tp_size=tensorrt_llm.mpi_world_size(),
-        rank=tensorrt_llm.mpi_rank(),
-    )
-
+def register_ar_residual_norm_fp4_quant(custom_pass: PatternMatcherPass,
+                                        mapping: Mapping):
     input_node = KeywordArg("input")
     strategy_node = KeywordArg("strategy")
     allreduce_default = CallFunction(torch.ops.trtllm.allreduce.default,
@@ -437,12 +406,8 @@ def extra_check(match: Match) -> bool:
     )
 
 
-def register_ub_patterns(custom_passes: List[PatternMatcherPass]):
-    mapping = Mapping(
-        world_size=tensorrt_llm.mpi_world_size(),
-        tp_size=tensorrt_llm.mpi_world_size(),
-        rank=tensorrt_llm.mpi_rank(),
-    )
+def register_ub_patterns(custom_passes: List[PatternMatcherPass],
+                         mapping: Mapping):
 
     def register_convert_supported_ar_to_ub(custom_pass: PatternMatcherPass):
         strategy = int(AllReduceStrategy.AUTO)
@@ -717,16 +682,16 @@ def target_finalize_pattern(
 
 
 def register_ar_fusions(custom_passes: List[PatternMatcherPass],
-                        enable_ub: bool):
-    register_ar_residual_norm(custom_passes[-1])
+                        mapping: Mapping, enable_ub: bool):
+    register_ar_residual_norm(custom_passes[-1], mapping)
     custom_passes.append(PatternMatcherPass())
-    register_ar_residual_norm_fp8_quant(custom_passes[-1])
-    register_ar_residual_norm_fp4_quant(custom_passes[-1])
+    register_ar_residual_norm_fp8_quant(custom_passes[-1], mapping)
+    register_ar_residual_norm_fp4_quant(custom_passes[-1], mapping)
     # AR-Residual-Norm-Out-Quant-X is not supported by Userbuffers kernel.
     if not enable_ub:
-        register_ar_residual_norm_out_fp8_quant(custom_passes[-1])
-        register_ar_residual_norm_out_fp4_quant(custom_passes[-1])
+        register_ar_residual_norm_out_fp8_quant(custom_passes[-1], mapping)
+        register_ar_residual_norm_out_fp4_quant(custom_passes[-1], mapping)
 
     if enable_ub:
-        register_ub_patterns(custom_passes)
+        register_ub_patterns(custom_passes, mapping)
diff --git a/tensorrt_llm/_torch/compilation/piecewise_optimizer.py b/tensorrt_llm/_torch/compilation/piecewise_optimizer.py
index 67d83c3c43f..bca4a4715ea 100644
--- a/tensorrt_llm/_torch/compilation/piecewise_optimizer.py
+++ b/tensorrt_llm/_torch/compilation/piecewise_optimizer.py
@@ -79,6 +79,10 @@ def call_module(self, target, args, kwargs):
                             dim_idx)
                         found_dynamic_shape = True
                         break
+        if not found_dynamic_shape:
+            raise RuntimeError(
+                "Cannot identify dynamic shape, please disable enable_piecewise_cuda_graph in TorchCompileConfig"
+            )
 
         if self.max_num_streams > 1 and not self.enable_inductor:
             num_events = multi_stream_schedule(submod, self.max_num_streams)
diff --git a/tensorrt_llm/_torch/compilation/utils.py b/tensorrt_llm/_torch/compilation/utils.py
index f74ce3f1d4a..39ab46de402 100644
--- a/tensorrt_llm/_torch/compilation/utils.py
+++ b/tensorrt_llm/_torch/compilation/utils.py
@@ -76,6 +76,12 @@ def inplace_info():
         },
         torch.ops.trtllm.logits_bitmask.default: {
             1: "logits"
+        },
+        torch.ops.trtllm.pp_recv_tensors.default: {
+            1: "tensors"
+        },
+        torch.ops.trtllm.pp_send_tensors.default: {
+            1: "tensors"
         }
     }
     return inplace_map
diff --git a/tensorrt_llm/_torch/distributed/communicator.py b/tensorrt_llm/_torch/distributed/communicator.py
index 3e8d1779679..67790b240a6 100644
--- a/tensorrt_llm/_torch/distributed/communicator.py
+++ b/tensorrt_llm/_torch/distributed/communicator.py
@@ -3,7 +3,7 @@
 import pickle  # nosec B403
 from abc import ABC, abstractmethod
 from functools import wraps
-from typing import Optional
+from typing import List, Optional
 
 import numpy as np
 import torch
@@ -848,3 +848,19 @@ def pp_recv(tensor):
 def pp_send(tensor):
     """Send tensors to next pp rank."""
     _pp_comm.send(tensor)
+
+
+@torch.library.custom_op("trtllm::pp_recv_tensors", mutates_args=("tensors", ))
+def pp_recv_tensors(tensors: List[torch.Tensor]) -> None:
+    """
+    Receive tensors from previous pp rank.
+    """
+    for tensor in tensors:
+        pp_recv(tensor)
+
+
+@torch.library.custom_op("trtllm::pp_send_tensors", mutates_args=("tensors", ))
+def pp_send_tensors(tensors: List[torch.Tensor]) -> None:
+    """Send tensors to next pp rank."""
+    for tensor in tensors:
+        pp_send(tensor)
diff --git a/tensorrt_llm/_torch/models/modeling_utils.py b/tensorrt_llm/_torch/models/modeling_utils.py
index d17bcab2df4..ddf253d41e7 100755
--- a/tensorrt_llm/_torch/models/modeling_utils.py
+++ b/tensorrt_llm/_torch/models/modeling_utils.py
@@ -18,7 +18,7 @@
 from ...logger import logger
 from ...models.modeling_utils import QuantConfig
 from ..attention_backend import AttentionMetadata
-from ..distributed.communicator import pp_recv, pp_send
+from ..distributed.communicator import pp_recv_tensors, pp_send_tensors
 from ..model_config import ModelConfig, TConfig
 from ..modules.attention import Attention
 from ..modules.embedding import Embedding, LMHead
@@ -172,11 +172,12 @@ def forward_after_recv_fn(
         residual=...,
         **kwargs,
     ):
-        pp_recv(hidden_states)
         if residual is not ...:
             if residual is None:
                 residual = torch.empty_like(hidden_states)
-            pp_recv(residual)
+            pp_recv_tensors([hidden_states, residual])
+        else:
+            pp_recv_tensors([hidden_states])
         return forward_fn(
             position_ids,
             hidden_states,
@@ -209,11 +210,10 @@ def forward_before_send_fn(
         )
         if residual is not ...:
             hidden_states, residual = output
-            pp_send(hidden_states)
-            pp_send(residual)
+            pp_send_tensors([hidden_states, residual])
         else:
             hidden_states = output
-            pp_send(hidden_states)
+            pp_send_tensors([hidden_states])
         return output
 
     forward_before_send_fn.__wrapped_by_forward_before_send__ = True
diff --git a/tensorrt_llm/_torch/modules/embedding.py b/tensorrt_llm/_torch/modules/embedding.py
index 046954da46d..09b6ad0d0dd 100644
--- a/tensorrt_llm/_torch/modules/embedding.py
+++ b/tensorrt_llm/_torch/modules/embedding.py
@@ -180,6 +180,24 @@ def pre_comm_embedding_ops(
     return output
 
 
+def embedding_skip_forward_impl(input: torch.Tensor, embedding_dim: int,
+                                dtype: torch.dtype) -> torch.Tensor:
+    output_shape = input.shape[:] + (embedding_dim, )
+    output = input.new_empty(output_shape, dtype=dtype)
+    return output
+
+
+@torch.library.custom_op("trtllm::embedding_skip_forward", mutates_args=())
+def embedding_skip_forward(input: torch.Tensor, embedding_dim: int,
+                           dtype: torch.dtype) -> torch.Tensor:
+    return embedding_skip_forward_impl(input, embedding_dim, dtype)
+
+
+@embedding_skip_forward.register_fake
+def _(input, embedding_dim, dtype):
+    return embedding_skip_forward_impl(input, embedding_dim, dtype)
+
+
 class Embedding(LMHead):
     """Embedding layer.
 
@@ -258,10 +276,4 @@ def forward(self, input):
         return output
 
     def skip_forward(self, input):
-        output_shape = input.shape[:] + (self.embedding_dim, )
-        output = torch.empty(
-            output_shape,
-            dtype=self.dtype,
-            device=input.device,
-        )
-        return output
+        return embedding_skip_forward(input, self.embedding_dim, self.dtype)
diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py
index 0e104185bc2..f4eeab954c5 100644
--- a/tensorrt_llm/_torch/pyexecutor/model_engine.py
+++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py
@@ -283,7 +283,8 @@ def __init__(
                 enable_piecewise_cuda_graph=self.
                 _torch_compile_piecewise_cuda_graph,
                 capture_num_tokens=self._piecewise_cuda_graph_num_tokens,
-                max_num_streams=torch_compile_max_num_streams)
+                max_num_streams=torch_compile_max_num_streams,
+                mapping=self.mapping)
             if isinstance(self.model, DecoderModelForCausalLM):
                 self.model.model = torch.compile(
                     self.model.model,
@@ -2824,7 +2825,7 @@ def _forward_step_mm_encoder_only(
         return {'mm_embeddings': mm_embeddings, 'logits': None}
 
     def _init_userbuffers(self, hidden_size):
-        if self.mapping.tp_size <= 1:
+        if self.mapping.tp_size <= 1 or self.mapping.pp_size > 1:
             return False
 
         # Disable UB for unsupported platforms
diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch.py b/tests/integration/defs/accuracy/test_llm_api_pytorch.py
index 75d636f130c..af5fb327762 100644
--- a/tests/integration/defs/accuracy/test_llm_api_pytorch.py
+++ b/tests/integration/defs/accuracy/test_llm_api_pytorch.py
@@ -122,11 +122,6 @@ def test_bfloat16(self, attn_backend, torch_compile):
                              ids=["tp4", "tp2pp2", "pp4"])
     def test_bfloat16_4gpus(self, tp_size, pp_size, attn_backend,
                             torch_compile):
-        if torch_compile and pp_size > 1:
-            pytest.skip(
-                "Pipeline parallel with torch.compile is not supported yet.\n"
-                "Issue: Unfusing flashinfer_fused_add_rmsnorm causes outputs to be "
-                "discarded at graph breaks.")
         torch_compile_config = TorchCompileConfig(
             enable_fullgraph=True,
             enable_piecewise_cuda_graph=True,
@@ -1331,9 +1326,6 @@ def test_bfloat16(self, mtp_nextn, attention_dp, cuda_graph,
     def test_bfloat16_4gpus(self, tp_size, pp_size, ep_size, mtp_nextn,
                             attention_dp, cuda_graph, overlap_scheduler,
                             torch_compile):
-        if torch_compile and pp_size > 1:
-            pytest.skip("PP with torch.compile is not supported yet.")
-
         if pp_size > 1 and mtp_nextn > 0:
             num_hidden_layers = 30
             pp_partition = [num_hidden_layers // pp_size + 1] * pp_size
@@ -1432,8 +1424,6 @@ def test_cute_dsl_fp8_block_scales(
         overlap_scheduler,
         torch_compile,
     ):
-        if torch_compile and attention_dp:
-            pytest.skip("https://nvbugs/5252559")
         kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.9)
         torch_compile_config = (TorchCompileConfig(
             enable_fullgraph=True,
@@ -1537,8 +1527,6 @@ def test_fp8_block_scales_cuda_graph_padding_4gpus(self, mtp_nextn,
     def test_fp8_block_scales_4gpus(self, tp_size, pp_size, ep_size, mtp_nextn,
                                     fp8kv, attention_dp, cuda_graph,
                                     overlap_scheduler, torch_compile):
-        if torch_compile and pp_size > 1:
-            pytest.skip("PP with torch.compile is not supported yet.")
         kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.75)
         torch_compile_config = TorchCompileConfig(
             enable_fullgraph=True,
@@ -1599,8 +1587,6 @@ def test_cute_dsl_fp8_block_scales_4gpus(
         overlap_scheduler,
         torch_compile,
     ):
-        if torch_compile and pp_size > 1:
-            pytest.skip("PP with torch.compile is not supported yet.")
         kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.9)
         torch_compile_config = (TorchCompileConfig(
             enable_fullgraph=True,
@@ -1826,8 +1812,6 @@ def test_nvfp4_batch_waiting(self, torch_compile, fp8kv, cuda_graph,
     def test_nvfp4_4gpus(self, fp8kv, attention_dp, cuda_graph,
                          overlap_scheduler, tp_size, pp_size, ep_size,
                          torch_compile, mtp_nextn, moe_backend):
-        if torch_compile and pp_size > 1:
-            pytest.skip("PP with torch.compile is not supported yet.")
         if moe_backend == "TRTLLM" and (get_sm_version() == 120
                                         or get_sm_version() == 121):
             pytest.skip(
diff --git a/tests/unittest/_torch/multi_gpu/test_ar_residual_norm.py b/tests/unittest/_torch/multi_gpu/test_ar_residual_norm.py
index 2d97dcfbb94..34a7cf5a1ea 100644
--- a/tests/unittest/_torch/multi_gpu/test_ar_residual_norm.py
+++ b/tests/unittest/_torch/multi_gpu/test_ar_residual_norm.py
@@ -66,7 +66,9 @@ def row_linear_residual_norm_fusion_forward(
         x: torch.Tensor, residual: torch.Tensor, hidden_size: int,
         dtype: torch.dtype, tensor_parallel_size: int,
         tensor_parallel_rank: int, weights: torch.Tensor, fused_add_norm: bool):
-    backend = Backend()
+    backend = Backend(mapping=Mapping(world_size=tensor_parallel_size,
+                                      tp_size=tensor_parallel_size,
+                                      rank=tensor_parallel_rank))
     x = x.cuda()
     residual = residual.cuda()
     norm_weight = torch.randn((hidden_size, ), dtype=dtype, device="cuda")
diff --git a/tests/unittest/_torch/multi_gpu/test_user_buffers.py b/tests/unittest/_torch/multi_gpu/test_user_buffers.py
index 340b2ea6282..c547c8a3e89 100644
--- a/tests/unittest/_torch/multi_gpu/test_user_buffers.py
+++ b/tests/unittest/_torch/multi_gpu/test_user_buffers.py
@@ -33,6 +33,14 @@
 pytestmark = pytest.mark.threadleak(enabled=False)
 
 
+def create_tp_mapping(tp_size, rank):
+    return Mapping(
+        world_size=tp_size,
+        tp_size=tp_size,
+        rank=rank,
+    )
+
+
 def init_userbuffers_allocator(tp_size, rank, max_ub_size):
     ub.initialize_userbuffers_manager(tp_size, 1, 1, rank,
                                       torch.cuda.device_count(), max_ub_size,
@@ -126,11 +134,7 @@ def run_single_rank_ar_rms_norm(tensor_parallel_size, a, b, c, gamma):
     ub0_tensor = create_userbuffers_tensor(c.size(), a.dtype)
     hidden = torch.matmul(a_local, b_local, out=ub0_tensor)
 
-    mapping = Mapping(
-        world_size=tensor_parallel_size,
-        tp_size=tensor_parallel_size,
-        rank=rank,
-    )
+    mapping = create_tp_mapping(tensor_parallel_size, rank)
     ar = AllReduce(mapping=mapping, strategy=AllReduceStrategy.UB)
     ar_params = AllReduceParams(
         strategy=AllReduceStrategy.UB,
@@ -217,11 +221,8 @@ def run_single_rank_ar_rms_norm_fp8(tensor_parallel_size, a, b, c, gamma,
     ub0_tensor = create_userbuffers_tensor(c.size(), a.dtype)
     hidden = torch.matmul(a_local, b_local, out=ub0_tensor)
 
-    mapping = Mapping(
-        world_size=tensor_parallel_size,
-        tp_size=tensor_parallel_size,
-        rank=rank,
-    )
+    mapping = create_tp_mapping(tensor_parallel_size, rank)
+
     ar = AllReduce(mapping=mapping, strategy=AllReduceStrategy.UB)
     ar_params = AllReduceParams(
         strategy=AllReduceStrategy.UB,
@@ -322,11 +323,8 @@ def __init__(self, tp_size, rank, hidden_size, dtype, eps, l0_weight,
             quant_config.layer_quant_mode
         self.rank = rank
         self.tp_size = tp_size
-        mapping = Mapping(
-            world_size=tp_size,
-            tp_size=tp_size,
-            rank=rank,
-        )
+        mapping = create_tp_mapping(tp_size, rank)
+
         self.l0 = Linear(in_features=hidden_size,
                          out_features=hidden_size,
                          bias=False,
@@ -452,7 +450,9 @@ def run_single_rank_ub_pass(
         quant(l3_weight, l3_weight_scale), l3_input_scale, l3_weight_scale,
         quant(l4_weight, l4_weight_scale), l4_input_scale, l4_weight_scale,
         norm0_gamma, norm1_gamma, norm2_gamma)
-    backend = Backend(enable_inductor=False, enable_userbuffers=True)
+    backend = Backend(enable_inductor=False,
+                      enable_userbuffers=True,
+                      mapping=create_tp_mapping(tensor_parallel_size, rank))
     model_opt = torch.compile(model, backend=backend, fullgraph=True)
     with torch.inference_mode():
         output_fused = model_opt(input)
@@ -602,11 +602,7 @@ def run_single_rank_ar_rms_norm_fp4(tensor_parallel_size, a, b, c, gamma):
     ub0_tensor = create_userbuffers_tensor(c.size(), a.dtype)
     hidden = torch.matmul(a_local, b_local, out=ub0_tensor)
 
-    mapping = Mapping(
-        world_size=tensor_parallel_size,
-        tp_size=tensor_parallel_size,
-        rank=rank,
-    )
+    mapping = create_tp_mapping(tensor_parallel_size, rank)
     ar = AllReduce(mapping=mapping, strategy=AllReduceStrategy.UB)
     ar_params = AllReduceParams(
         strategy=AllReduceStrategy.UB,
@@ -688,11 +684,7 @@ def __init__(self, tp_size, rank, hidden_size, dtype, eps, norm0_gamma,
         self.rank = rank
         self.hidden_size = hidden_size
         self.dtype = dtype
-        mapping = Mapping(
-            world_size=tp_size,
-            tp_size=tp_size,
-            rank=rank,
-        )
+        mapping = create_tp_mapping(tp_size, rank)
         self.ar_0 = AllReduce(mapping=mapping).cuda()
         self.ar_1 = AllReduce(mapping=mapping).cuda()
         self.ar_2 = AllReduce(mapping=mapping).cuda()
@@ -748,7 +740,9 @@ def run_single_rank_ub_mm_add_pass(tensor_parallel_size, num_tokens,
     init_userbuffers_allocator(tensor_parallel_size, rank, ub_size)
     model = UBMMAddModel(tensor_parallel_size, rank, hidden_size, dtype, eps,
                          norm0_gamma, norm1_gamma, norm2_gamma)
-    backend = Backend(enable_inductor=False, enable_userbuffers=True)
+    backend = Backend(enable_inductor=False,
+                      enable_userbuffers=True,
+                      mapping=create_tp_mapping(tensor_parallel_size, rank))
     model_opt = torch.compile(model, backend=backend, fullgraph=True)
     with torch.inference_mode():
         output_fused = model_opt(mm0_input_0, mm0_input_1, mm1_input_0,
@@ -819,59 +813,40 @@ def __init__(self, tp_size, rank, hidden_size, dtype, eps, l0_weight,
             quant_config.layer_quant_mode
         self.rank = rank
         self.tp_size = tp_size
+        mapping = create_tp_mapping(tp_size, rank)
         self.l0 = Linear(in_features=hidden_size,
                          out_features=hidden_size,
                          bias=False,
                          dtype=dtype,
-                         mapping=Mapping(
-                             world_size=tp_size,
-                             tp_size=tp_size,
-                             rank=rank,
-                         ),
+                         mapping=mapping,
                          tensor_parallel_mode=TensorParallelMode.ROW,
                          quant_config=quant_config).cuda()
         self.l1 = Linear(in_features=hidden_size,
                          out_features=hidden_size,
                          bias=False,
                          dtype=dtype,
-                         mapping=Mapping(
-                             world_size=tp_size,
-                             tp_size=tp_size,
-                             rank=rank,
-                         ),
+                         mapping=mapping,
                          tensor_parallel_mode=TensorParallelMode.COLUMN,
                          quant_config=quant_config).cuda()
         self.l2 = Linear(in_features=hidden_size,
                          out_features=hidden_size,
                          bias=False,
                          dtype=dtype,
-                         mapping=Mapping(
-                             world_size=tp_size,
-                             tp_size=tp_size,
-                             rank=rank,
-                         ),
+                         mapping=mapping,
                          tensor_parallel_mode=TensorParallelMode.ROW,
                          quant_config=quant_config).cuda()
         self.l3 = Linear(in_features=hidden_size,
                          out_features=hidden_size,
                          bias=False,
                          dtype=dtype,
-                         mapping=Mapping(
-                             world_size=tp_size,
-                             tp_size=tp_size,
-                             rank=rank,
-                         ),
+                         mapping=mapping,
                          tensor_parallel_mode=TensorParallelMode.COLUMN,
                          quant_config=quant_config).cuda()
         self.l4 = Linear(in_features=hidden_size,
                          out_features=hidden_size,
                          bias=False,
                          dtype=dtype,
-                         mapping=Mapping(
-                             world_size=tp_size,
-                             tp_size=tp_size,
-                             rank=rank,
-                         ),
+                         mapping=mapping,
                          tensor_parallel_mode=TensorParallelMode.ROW,
                          quant_config=quant_config).cuda()
         self.norm0 = RMSNorm(hidden_size=hidden_size,
                              eps=eps,
@@ -1006,7 +981,9 @@ def block_scale_unswizzled(scale):
         l4_weight_scale_block_unswizzled.view(torch.float8_e4m3fn),
         l4_weight_scale, norm0_gamma, norm1_gamma, norm2_gamma)
 
-    backend = Backend(enable_inductor=False, enable_userbuffers=True)
+    backend = Backend(enable_inductor=False,
+                      enable_userbuffers=True,
+                      mapping=create_tp_mapping(tensor_parallel_size, rank))
     model_opt = torch.compile(model, backend=backend, fullgraph=True)
     with torch.inference_mode():
         output_ref = model(input)
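
Usage note (not part of the patch): the compile backend now takes the parallel
topology explicitly through a Mapping instead of deriving a TP-only mapping from
the MPI world size inside the fusion passes. A minimal sketch, mirroring the
updated multi-GPU tests above; the `model` object is assumed to exist and the
values below are illustrative:

    import torch
    import tensorrt_llm
    from tensorrt_llm.mapping import Mapping
    from tensorrt_llm._torch.compilation.backend import Backend

    # Build a tensor-parallel mapping for this rank, as the tests above do.
    tp_size = tensorrt_llm.mpi_world_size()
    mapping = Mapping(world_size=tp_size,
                      tp_size=tp_size,
                      rank=tensorrt_llm.mpi_rank())

    # Pass the mapping to the backend so the allreduce fusion patterns are
    # registered with the correct TP/PP topology.
    backend = Backend(enable_inductor=False,
                      enable_userbuffers=True,
                      mapping=mapping)
    model_opt = torch.compile(model, backend=backend, fullgraph=True)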