README.md: 35 changes (30 additions, 5 deletions)
@@ -73,29 +73,45 @@ Here are some example outputs with Flux.1-Schnell for prompt `"A cat playing with a ball of yarn"`
We rely primarily on pure PyTorch for the optimizations. Currently, a relatively recent nightly version of PyTorch is required.

The numbers reported here were gathered using:

For NVIDIA:
* `torch==2.8.0.dev20250605+cu126` - note that we rely on some fixes since 2.7
* `torchao==0.12.0.dev20250610+cu126` - note that we rely on a fix in the 06/10 nightly
* `diffusers` - with [this fix](https://github.com/huggingface/diffusers/pull/11696) included
* `flash_attn_3==3.0.0b1`

To install deps:
For AMD:
* `torch==2.8.0.dev20250605+rocm6.4` - note that we rely on some fixes since 2.7
* `torchao==0.12.0.dev20250610+rocm6.4` - note that we rely on a fix in the 06/10 nightly
* `diffusers` - with [this fix](https://github.com/huggingface/diffusers/pull/11696) included
* `aiter-0.1.4.dev17+gd0384d4`

To install deps on NVIDIA:
```
pip install -U huggingface_hub[hf_xet] accelerate transformers
pip install -U diffusers
pip install --pre torch==2.8.0.dev20250605+cu126 --index-url https://download.pytorch.org/whl/nightly/cu126
pip install --pre torchao==0.12.0.dev20250609+cu126 --index-url https://download.pytorch.org/whl/nightly/cu126
```

To install flash attention v3, follow the instructions in https://github.com/Dao-AILab/flash-attention#flashattention-3-beta-release.
To install deps on AMD:
```
pip install -U diffusers
pip install --pre torch==2.8.0.dev20250605+rocm6.4 --index-url https://download.pytorch.org/whl/nightly/rocm6.4
pip install --pre torchao==0.12.0.dev20250609+rocm6.4 --index-url https://download.pytorch.org/whl/nightly/rocm6.4
pip install git+https://github.com/ROCm/aiter
```

(For NVIDIA) To install flash attention v3, follow the instructions in https://github.com/Dao-AILab/flash-attention#flashattention-3-beta-release.
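
As a quick sanity check (illustrative, not part of the repo), the key packages can be imported to confirm the environment is set up for the right backend:

```python
import torch, torchao, diffusers
print(torch.__version__, torchao.__version__, diffusers.__version__)

if torch.version.hip is not None:   # AMD / ROCm build
    from aiter.ops.triton.mha import flash_attn_fp8_func  # noqa: F401
else:                               # NVIDIA build
    import flash_attn_interface     # noqa: F401  (installed by the FA3 beta)
```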

For hardware, we used a 96GB 700W H100 GPU. Some of the optimizations applied (BFloat16, torch.compile, Combining q,k,v projections, dynamic float8 quantization) are available on CPU as well.
For hardware, we used a 96GB 700W H100 GPU and 192GB MI300X GPU. Some of the optimizations applied (BFloat16, torch.compile, Combining q,k,v projections, dynamic float8 quantization) are available on CPU as well.
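
For example, the CPU-portable subset can be sketched on a plain module (a minimal illustration of BFloat16 + `torch.compile` + torchao dynamic float8 quantization on a toy model; not code from this repo):

```python
import torch
from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight

device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Sequential(
    torch.nn.Linear(1024, 1024), torch.nn.GELU(), torch.nn.Linear(1024, 1024)
).to(device=device, dtype=torch.bfloat16)                   # BFloat16

quantize_(model, float8_dynamic_activation_float8_weight())  # dynamic float8 quantization
model = torch.compile(model, mode="max-autotune", fullgraph=True)  # torch.compile

with torch.no_grad():
    out = model(torch.randn(8, 1024, device=device, dtype=torch.bfloat16))
```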

## Run the optimized pipeline

On NVIDIA:
```sh
python gen_image.py --prompt "An astronaut standing next to a giant lemon" --output-file output.png --use-cached-model
```

This will include all optimizations and will attempt to use pre-cached binary models
generated via `torch.export` + AOTI. To generate these binaries for subsequent runs, run
the above command without the `--use-cached-model` flag.
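
Under the hood this corresponds roughly to the standard `torch.export` + AOTInductor flow, sketched below (illustrative only: `pipe` is an already-loaded pipeline and `example_kwargs` stands in for real transformer inputs; the repo's own caching logic may differ):

```python
import torch

# First run: export the denoising transformer and package an AOTI binary.
ep = torch.export.export(pipe.transformer, args=(), kwargs=example_kwargs)
torch._inductor.aoti_compile_and_package(ep, package_path="transformer.pt2")

# Subsequent runs (--use-cached-model): load the precompiled artifact instead of recompiling.
transformer_fn = torch._inductor.aoti_load_package("transformer.pt2")
out = transformer_fn(**example_kwargs)
```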
@@ -108,6 +124,13 @@
> different environment than the one present at runtime. The PyTorch Compiler team is working on
> solutions for more portable binaries / artifact caching.

On AMD:
```sh
python gen_image.py --prompt "A cat playing with a ball of yarn" --output-file output.png --compile_export_mode compile
```
This will include all optimizations except the `torch.export` + AOTI ones.


## Benchmarking
[`run_benchmark.py`](./run_benchmark.py) is the main script for benchmarking the different optimization techniques.
Usage:
@@ -326,7 +349,7 @@ image = pipe(prompt, num_inference_steps=4).images[0]
</details>

<details>
<summary>Flash Attention V3</summary>
<summary>Flash Attention V3 / aiter</summary>

Flash Attention V3 is substantially faster on H100s than the previous iteration FA2, due
in large part to float8 support. As this kernel isn't quite available yet within PyTorch Core, we implement a custom
@@ -335,6 +358,8 @@ image = pipe(prompt, num_inference_steps=4).images[0]
the op integrates well with `torch.compile` / `torch.export`. Inputs are converted to float8 in an unscaled fashion before
kernel invocation and outputs are converted back to the original dtype on the way out.

On AMD GPUs, we use [`aiter`](https://github.com/ROCm/aiter) instead, which also provides fp8 MHA kernels.
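
Schematically, the wrapper looks like this (a simplified sketch, not the exact code: `run_fp8_attention_kernel` is a hypothetical stand-in for the FA3 / aiter call, and the real implementation in `utils/pipeline_utils.py` forwards additional kwargs):

```python
import torch

@torch.library.custom_op("flash::fp8_attention_sketch", mutates_args=())
def fp8_attention_sketch(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    orig_dtype = q.dtype
    # Unscaled downcast to float8 before invoking the kernel...
    q8, k8, v8 = (t.to(torch.float8_e4m3fn) for t in (q, k, v))
    out = run_fp8_attention_kernel(q8, k8, v8)  # hypothetical: FA3 on NVIDIA, aiter on AMD
    # ...and back to the original dtype on the way out.
    return out.to(orig_dtype)

@fp8_attention_sketch.register_fake
def _(q, k, v):
    # Shape/dtype-only stand-in so torch.compile / torch.export can trace through the op.
    return torch.empty_like(q)
```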

```python
from diffusers import FluxPipeline

utils/pipeline_utils.py: 43 changes (30 additions, 13 deletions)
@@ -7,6 +7,9 @@
from PIL import Image
import inspect

def is_hip():
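# torch.version.hip is set only on ROCm builds of PyTorch, so this distinguishes AMD (HIP) from NVIDIA (CUDA) runs.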
return torch.version.hip is not None


@torch.library.custom_op("flash::flash_attn_func", mutates_args=())
def flash_attn_func(
@@ -34,11 +37,14 @@ def flash_attn_func(
else:
window_size = tuple(window_size)

import flash_attn_interface
if is_hip():
from aiter.ops.triton.mha import flash_attn_fp8_func as flash_attn_interface_func
else:
from flash_attn_interface import flash_attn_func as flash_attn_interface_func

dtype = torch.float8_e4m3fn

sig = inspect.signature(flash_attn_interface.flash_attn_func)
sig = inspect.signature(flash_attn_interface_func)
accepted = set(sig.parameters)
all_kwargs = {
"softmax_scale": softmax_scale,
@@ -57,11 +63,11 @@ def flash_attn_func(
}
kwargs = {k: v for k, v in all_kwargs.items() if k in accepted}

outputs = flash_attn_interface.flash_attn_func(
q.to(dtype), k.to(dtype), v.to(dtype), **kwargs,
outputs = flash_attn_interface_func(
q, k, v, **kwargs,
)
return outputs[0]

# aiter returns the attention output directly (made contiguous and cast back to bf16); FA3 returns a tuple, so take the first element
return outputs.contiguous().to(torch.bfloat16) if is_hip() else outputs[0]

@flash_attn_func.register_fake
def _(q, k, v, **kwargs):
@@ -71,18 +77,29 @@ def _(q, k, v, **kwargs):
meta_q = torch.empty_like(q).contiguous()
return meta_q #, q.new_empty((q.size(0), q.size(2), q.size(1)), dtype=torch.float32)

# fake_input = [torch.randn((1, 4608, 24, 128), device="cuda", dtype=torch.float32) for _ in range(3)]
# torch.library.opcheck(flash_attn_func, fake_input)

# Copied from FusedFluxAttnProcessor2_0 but using flash v3 (or aiter on ROCm) instead of SDPA
class FlashFusedFluxAttnProcessor3_0:
"""Attention processor used typically in processing the SD3-like self-attention projections."""

def __init__(self):
try:
import flash_attn_interface
except ImportError:
raise ImportError(
"flash_attention v3 package is required to be installed"
)

if is_hip():
try:
from aiter.ops.triton.mha import flash_attn_fp8_func as flash_attn_interface_func
except ImportError:
raise ImportError(
"aiter is required to be installed"
)
else:
try:
from flash_attn_interface import flash_attn_func as flash_attn_interface_func
except ImportError:
raise ImportError(
"flash_attention v3 package is required to be installed"
)

def __call__(
self,
@@ -215,10 +232,10 @@ def use_compile(pipeline):
# Compile the compute-intensive portions of the model: denoising transformer / decoder
is_kontext = "Kontext" in pipeline.__class__.__name__
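# dynamic=None lets torch.compile choose dynamic shapes automatically; on ROCm (HIP) we request dynamic shapes explicitly.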
pipeline.transformer = torch.compile(
pipeline.transformer, mode="max-autotune", fullgraph=True
pipeline.transformer, mode="max-autotune", fullgraph=True, dynamic=True if is_hip() else None
)
pipeline.vae.decode = torch.compile(
pipeline.vae.decode, mode="max-autotune", fullgraph=True
pipeline.vae.decode, mode="max-autotune", fullgraph=True, dynamic=True if is_hip() else None
)

# warmup for a few iterations (`num_inference_steps` shouldn't matter)