Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
496 changes: 496 additions & 0 deletions examples/distributed_inference/llama3_model.py

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions examples/distributed_inference/rotary_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,20 +84,20 @@ def parallel_rotary_block(rotary_block, tp_mesh):
"wk": ColwiseParallel(),
"wo": RowwiseParallel(output_layouts=Shard(0)),
}
rotary_block.n_parallel = 1 # this is for single GPU, to do remove this hardcode
rotary_block.n_parallel = tp_mesh.size()

parallelize_module(rotary_block, tp_mesh, plan)


class RotaryAttention(nn.Module):
def __init__(self, dim: int, seq_len: int):
def __init__(self, dim: int, seq_len: int, n_parallel: int = 1):
super().__init__()
self.dim = dim
self.wq = nn.Linear(dim, dim)
self.wk = nn.Linear(dim, dim)
self.wo = nn.Linear(dim, dim)
self.seq_len = seq_len
self.n_parallel = 1
self.n_parallel = n_parallel
self.register_buffer("freqs_cis", self._precompute_freqs_cis(), persistent=True)
self.init_weights()

Expand Down
72 changes: 72 additions & 0 deletions examples/distributed_inference/tensor_parallel_llama3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
# Adapted (with modifications) from the PyTorch Lightning tensor-parallelism studio:
# https://lightning.ai/lightning-ai/studios/tensor-parallelism-supercharging-large-model-training-with-pytorch-lightning
import logging
import os
import time

import torch
import torch.distributed as dist
from llama3_model import ModelArgs, ParallelTransformer
from tensor_parallel_initialize_dist import (
cleanup_distributed_env,
initialize_distributed_env,
)
from torch.distributed._composable.fsdp import MixedPrecisionPolicy
from torch.distributed._composable.fsdp.fully_shard import fully_shard
from torch.distributed._tensor import Replicate, Shard
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
checkpoint_wrapper,
)

# Set up the distributed environment once per process; skip it when a process
# group already exists.
# NOTE(review): the torch_tensorrt import is placed after this guard, which
# suggests the environment has to be initialized first — confirm before
# reordering the imports.
if not dist.is_initialized():
    initialize_distributed_env()

import torch_tensorrt
from torch_tensorrt.dynamo.distributed.utils import (
get_tensor_parallel_device_mesh,
initialize_distributed_logger,
)

# Build the tensor-parallel device mesh and a logger tagged with this rank.
device_mesh, _world_size, _rank = get_tensor_parallel_device_mesh()
logger = initialize_distributed_logger(_rank, "tensor_parallel_llama3")

# Run the example under try/finally so the distributed environment is always
# torn down, even when one of the checks below fails.
try:
    logger.info(f"Starting PyTorch TP example on rank {_rank}.")
    assert (
        _world_size % 2 == 0
    ), f"TP examples require even number of GPUs, but got {_world_size} gpus"

    # Model configuration for the example; vocab_size matches the upper bound
    # used for the random token ids generated below.
    model_args = ModelArgs(
        vocab_size=32000,
        dim=1024,
        n_layers=4,
        n_heads=8,
        rope_theta=500000.0,
        n_kv_heads=8,
        device="cuda",
    )

    with torch.no_grad():
        model = ParallelTransformer(model_args, device_mesh)

        # Fixed seed so the random input is reproducible.
        torch.manual_seed(0)
        # Integer tensor of shape (8, 256) with values in [0, 32000)
        # (presumably batch x sequence length — TODO confirm).
        inp = torch.randint(32000, (8, 256), device="cuda")

        # Reference output from the uncompiled model, used for the final check.
        python_result = model(inp)

        torch_tensorrt.runtime.set_multi_device_safe_mode(True)
        model = torch.compile(
            model,
            fullgraph=True,
            backend="torch_tensorrt",
            options={
                "use_python_runtime": True,
                "use_distributed_mode_trace": True,
                "debug": True,
            },
            dynamic=False,
        )

        # The timed interval is the first call of the compiled model.
        # NOTE(review): torch.compile is expected to compile lazily on this
        # first call, so the figure should include compilation — confirm.
        # perf_counter() is monotonic, so it is safe for measuring intervals.
        start = time.perf_counter()
        output = model(inp)
        end = time.perf_counter()
        logger.info(f"Compilation time is {end-start}")

        # Compare the compiled output against the eager reference.
        assert (
            python_result - output
        ).std() < 0.01, "Compilation result is not correct."
finally:
    cleanup_distributed_env()
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
DIM = 128

with torch.no_grad():
model = RotaryAttention(DIM, SEQ_LEN)
model = RotaryAttention(DIM, SEQ_LEN, device_mesh.size())
parallel_rotary_block(model, device_mesh)
device = torch.device("cuda", device_mesh.get_rank())
model.to(device)
Expand Down
9 changes: 1 addition & 8 deletions py/torch_tensorrt/dynamo/_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from torch_tensorrt._Device import Device
from torch_tensorrt._enums import EngineCapability, dtype
from torch_tensorrt._features import needs_cross_compile
from torch_tensorrt._Input import Input
from torch_tensorrt.dynamo import _defaults, partitioning
from torch_tensorrt.dynamo._DryRunTracker import (
DryRunTracker,
Expand Down Expand Up @@ -287,7 +286,6 @@ def cross_compile_for_windows(
arg_inputs = [arg_inputs] # type: ignore

# Prepare torch_trt inputs
trt_arg_inputs: Sequence[Input] = prepare_inputs(arg_inputs)
trt_kwarg_inputs: Optional[dict[Any, Any]] = prepare_inputs(kwarg_inputs)
device = to_torch_tensorrt_device(device)
enabled_precisions = {dtype._from(p) for p in enabled_precisions}
Expand Down Expand Up @@ -377,7 +375,6 @@ def cross_compile_for_windows(
)
trt_gm = compile_module(
gm,
trt_arg_inputs,
trt_kwarg_inputs,
settings,
)
Expand Down Expand Up @@ -623,7 +620,6 @@ def compile(
arg_inputs = [arg_inputs] # type: ignore

# Prepare torch_trt inputs
trt_arg_inputs: Sequence[Input] = prepare_inputs(arg_inputs)
trt_kwarg_inputs: Optional[dict[Any, Any]] = prepare_inputs(kwarg_inputs)
device = to_torch_tensorrt_device(device)
enabled_precisions = {dtype._from(p) for p in enabled_precisions}
Expand Down Expand Up @@ -709,16 +705,13 @@ def compile(
logger.warning(
"Remaining GPU memory may not be enough to compile the TensorRT engine for this model resulting in an OOM error, Consider setting offload_module_to_cpu=True"
)
trt_gm = compile_module(
gm, trt_arg_inputs, trt_kwarg_inputs, settings, engine_cache
)
trt_gm = compile_module(gm, trt_kwarg_inputs, settings, engine_cache)
return trt_gm


@fn_supports_debugger # type: ignore[misc]
def compile_module(
gm: torch.fx.GraphModule,
sample_arg_inputs: Sequence[Input],
sample_kwarg_inputs: Optional[dict[Any, Any]] = None,
settings: CompilationSettings = CompilationSettings(),
engine_cache: Optional[BaseEngineCache] = None,
Expand Down
5 changes: 0 additions & 5 deletions py/torch_tensorrt/dynamo/backend/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
)
from torch_tensorrt.dynamo.utils import (
parse_dynamo_kwargs,
prepare_inputs,
set_log_level,
)

Expand Down Expand Up @@ -150,9 +149,6 @@ def _pretraced_backend(

logger.debug("Lowered Input graph:\n " + str(gm.graph))

torchtrt_inputs = prepare_inputs(
torch_inputs, disable_memory_format_check=True
)
if settings.require_full_compilation:
logger.warning(
"require_full_compilation arg is not applicable for torch.compile with backend='torch_tensorrt"
Expand All @@ -163,7 +159,6 @@ def _pretraced_backend(
)
trt_compiled = compile_module(
gm,
torchtrt_inputs,
settings=settings,
engine_cache=engine_cache,
)
Expand Down

This file was deleted.

Loading
Loading