
Commit 7caeb6e

Support SDXL and its distributed inference
1 parent 3c46c27 commit 7caeb6e

17 files changed: +2374 −132 lines
Lines changed: 28 additions & 0 deletions
@@ -0,0 +1,28 @@
# Stable Diffusion XL

This document explains how to build the [SDXL](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) model into runnable TensorRT engines on a single GPU or multiple GPUs, and how to run an image-generation task with those engines.

The design of the distributed parallel inference follows the CVPR 2024 paper [DistriFusion](https://github.com/mit-han-lab/distrifuser). To keep the implementation simple, all communication in this example is synchronous.
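For intuition, here is a minimal, hypothetical PyTorch sketch of one synchronous patch-parallel denoising step in the spirit of DistriFusion: each rank runs the UNet on its own slice of the latent, then blocks on an `all_gather` to reassemble the full noise prediction. The names `unet_forward` and `denoise_step` are illustrative only, and the sketch ignores details such as halo exchange at patch borders.

```python
import torch
import torch.distributed as dist


def denoise_step(unet_forward, latent, t, rank, world_size):
    # Each rank keeps only its own patch of the latent (split along height).
    patch = latent.chunk(world_size, dim=2)[rank]
    local_pred = unet_forward(patch, t)

    # Synchronous communication: every rank blocks here until the noise
    # predictions from all peers have arrived.
    gathered = [torch.empty_like(local_pred) for _ in range(world_size)]
    dist.all_gather(gathered, local_pred)
    return torch.cat(gathered, dim=2)
```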
## Usage
8+
9+
### 1. Build TensorRT Engine(s)
10+
11+
```bash
12+
# 1 gpu
13+
python build_sdxl_unet.py --size 1024
14+
15+
# 2 gpus
16+
mpirun -n 2 python build_sdxl_unet.py --size 1024
17+
```
18+
19+
### 2. Generate images using the engine(s)
20+
21+
22+
```bash
23+
# 1 gpu
24+
python run_sdxl.py --size 1024 --prompt "flowers, rabbit"
25+
26+
# 2 gpus
27+
mpirun -n 2 python run_sdxl.py --size 1024 --prompt "flowers, rabbit"
28+
```
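As a rough guide to what each rank does (a hedged sketch; the actual assignment is handled by `tensorrt_llm.Mapping` in `build_sdxl_unet.py` below): with 2 GPUs, each rank computes one branch of the classifier-free-guidance batch; with 4 GPUs, each branch is additionally split into two latent patches. The helper below is hypothetical and only mirrors the layout arithmetic:

```python
def parallel_layout(world_size: int, rank: int):
    # Mirrors build_sdxl_unet.py: tp_size = patch parallelism,
    # pp_size = batch parallelism.
    device_per_batch = world_size // 2 if world_size > 1 else 1
    # Assumption: consecutive ranks share a batch group.
    batch_idx, patch_idx = divmod(rank, device_per_batch)
    return batch_idx, patch_idx


for ws in (1, 2, 4):
    print(ws, [parallel_layout(ws, r) for r in range(ws)])
# 1 [(0, 0)]
# 2 [(0, 0), (1, 0)]
# 4 [(0, 0), (0, 1), (1, 0), (1, 1)]
```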
Lines changed: 144 additions & 0 deletions
@@ -0,0 +1,144 @@
import argparse
import os

import tensorrt as trt
import torch
from diffusers import DiffusionPipeline

import tensorrt_llm
from tensorrt_llm.builder import Builder
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.models.unet.pp.unet_pp import DistriUNetPP
from tensorrt_llm.models.unet.unet_2d_condition import UNet2DConditionModel
from tensorrt_llm.models.unet.weights import load_from_hf_unet
from tensorrt_llm.network import net_guard

parser = argparse.ArgumentParser(description='build the UNet TensorRT engine.')
parser.add_argument('--size', type=int, default=1024, help='image size')
parser.add_argument('--output_dir',
                    type=str,
                    default=None,
                    help='output directory')

args = parser.parse_args()

size = args.size
sample_size = size // 8  # latent resolution: the SDXL VAE downsamples by 8x

world_size = tensorrt_llm.mpi_world_size()
rank = tensorrt_llm.mpi_rank()
output_dir = f'sdxl_s{size}_w{world_size}' if args.output_dir is None else args.output_dir
if rank == 0 and not os.path.exists(output_dir):
    os.makedirs(output_dir)

device_per_batch = world_size // 2 if world_size > 1 else 1
batch_group = 2 if world_size > 1 else 1

# Use tp_size to indicate the size of patch parallelism.
# Use pp_size to indicate the size of batch parallelism.
mapping = Mapping(world_size=world_size,
                  rank=rank,
                  tp_size=device_per_batch,
                  pp_size=batch_group)

torch.cuda.set_device(tensorrt_llm.mpi_rank())

tensorrt_llm.logger.set_level('verbose')
builder = Builder()
builder_config = builder.create_builder_config(
    name='UNet2DConditionModel',
    precision='float16',
    timing_cache='model.cache',
    profiling_verbosity='detailed',
    tensor_parallel=world_size,
    # Do not set precision_constraints to 'obey'; the precision error
    # becomes too large.
    precision_constraints=None,
)

pipeline = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16)
model = UNet2DConditionModel(
    sample_size=sample_size,
    in_channels=4,
    out_channels=4,
    center_input_sample=False,
    flip_sin_to_cos=True,
    freq_shift=0,
    down_block_types=("DownBlock2D", "CrossAttnDownBlock2D",
                      "CrossAttnDownBlock2D"),
    up_block_types=("CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "UpBlock2D"),
    block_out_channels=(320, 640, 1280),
    layers_per_block=2,
    downsample_padding=1,
    mid_block_scale_factor=1.0,
    act_fn="silu",
    norm_num_groups=32,
    norm_eps=1e-5,
    cross_attention_dim=2048,
    attention_head_dim=[5, 10, 20],
    addition_embed_type="text_time",
    addition_time_embed_dim=256,
    projection_class_embeddings_input_dim=2816,
    transformer_layers_per_block=[1, 2, 10],
    use_linear_projection=True,
    dtype=trt.float16,
)

load_from_hf_unet(pipeline.unet, model)
model = DistriUNetPP(model, mapping)

# Module -> Network
network = builder.create_network()
network.plugin_config.to_legacy_setting()
if mapping.world_size > 1:
    network.plugin_config.set_nccl_plugin('float16')

with net_guard(network):
    # Prepare
    network.set_named_parameters(model.named_parameters())

    # Forward: static input shapes; the leading dimension of 2 holds the
    # conditional/unconditional pair for classifier-free guidance.
    sample = tensorrt_llm.Tensor(
        name='sample',
        dtype=trt.float16,
        shape=[2, 4, sample_size, sample_size],
    )
    timesteps = tensorrt_llm.Tensor(
        name='timesteps',
        dtype=trt.float16,
        shape=[1],
    )
    encoder_hidden_states = tensorrt_llm.Tensor(
        name='encoder_hidden_states',
        dtype=trt.float16,
        shape=[2, 77, 2048],
    )
    text_embeds = tensorrt_llm.Tensor(
        name='text_embeds',
        dtype=trt.float16,
        shape=[2, 1280],
    )
    time_ids = tensorrt_llm.Tensor(
        name='time_ids',
        dtype=trt.float16,
        shape=[2, 6],
    )

    output = model(sample, timesteps, encoder_hidden_states, text_embeds,
                   time_ids)

    # Mark outputs
    output_dtype = trt.float16
    output.mark_output('pred', output_dtype)

# Network -> Engine
engine = builder.build_engine(network, builder_config)
assert engine is not None, 'Failed to build engine.'

engine_name = f'sdxl_unet_s{size}_w{world_size}_r{rank}.engine'
engine_path = os.path.join(output_dir, engine_name)
with open(engine_path, 'wb') as f:
    f.write(engine)
builder.save_config(builder_config, os.path.join(output_dir, 'config.json'))
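Once built, the serialized engine can be sanity-checked with the plain TensorRT runtime. A minimal sketch, assuming TensorRT 8.5+ (named-I/O API) and the default output directory for `--size 1024` on a single GPU:

```python
import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
with open('sdxl_s1024_w1/sdxl_unet_s1024_w1_r0.engine', 'rb') as f:
    engine = trt.Runtime(logger).deserialize_cuda_engine(f.read())

# List the I/O tensors the build script defined: sample, timesteps,
# encoder_hidden_states, text_embeds, time_ids -> pred.
for i in range(engine.num_io_tensors):
    name = engine.get_tensor_name(i)
    print(name, engine.get_tensor_mode(name), engine.get_tensor_shape(name))
```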
