67 changes: 67 additions & 0 deletions docs/user_guide/LOAD_CONFIGS.md
@@ -0,0 +1,67 @@
# Use YAML Config File

Cache-DiT now supports loading acceleration configs from a custom YAML file. Here are some examples.

## Single GPU inference

Define a `config.yaml` file that contains:

```yaml
cache_config:
  max_warmup_steps: 8
  warmup_interval: 2
  max_cached_steps: -1
  max_continuous_cached_steps: 2
  Fn_compute_blocks: 1
  Bn_compute_blocks: 0
  residual_diff_threshold: 0.12
  enable_taylorseer: true
  taylorseer_order: 1
```
Then, apply the acceleration config from the YAML file.

```python
>>> import cache_dit
>>> cache_dit.enable_cache(pipe, **cache_dit.load_configs("config.yaml"))
```
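
For context, here is a minimal end-to-end sketch. The pipeline class and checkpoint id are assumptions chosen only for illustration (the `flux` example in this PR suggests them), not part of the documented API:

```python
# Minimal sketch (FluxPipeline and the checkpoint id are assumptions for
# illustration only): load a diffusers pipeline, then apply config.yaml.
import torch
import cache_dit
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",  # assumed checkpoint
    torch_dtype=torch.bfloat16,
).to("cuda")

# One call applies every cache setting declared in config.yaml.
cache_dit.enable_cache(pipe, **cache_dit.load_configs("config.yaml"))

image = pipe(
    "a photo of an astronaut riding a horse",
    num_inference_steps=28,
).images[0]
image.save("flux_cached.png")
```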

## Distributed inference

Define a `parallel_config.yaml` file that contains:

```yaml
cache_config:
  max_warmup_steps: 8
  warmup_interval: 2
  max_cached_steps: -1
  max_continuous_cached_steps: 2
  Fn_compute_blocks: 1
  Bn_compute_blocks: 0
  residual_diff_threshold: 0.12
  enable_taylorseer: true
  taylorseer_order: 1
parallelism_config:
  ulysses_size: auto
  parallel_kwargs:
    attention_backend: native
    extra_parallel_modules: ["text_encoder", "vae"]
```
Then, apply the distributed inference acceleration config from the YAML file. `ulysses_size: auto` means that cache-dit will automatically detect the `world_size` and use it as the `ulysses_size`; otherwise, set it manually to a specific integer, e.g., 4.

```python
>>> import cache_dit
>>> cache_dit.enable_cache(pipe, **cache_dit.load_configs("parallel_config.yaml"))
```
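
The resolution of `ulysses_size: auto` presumably mirrors what the new `examples/utils.py` code does when `args.parallel_type == "ulysses"` (it passes `dist.get_world_size()`). The helper below is only a sketch of that assumption, not the actual cache-dit implementation:

```python
# Sketch only: how "auto" is presumably mapped to a concrete ulysses_size.
# The real logic lives inside cache_dit.load_configs / the parallelism backend.
import torch.distributed as dist

def resolve_ulysses_size(value):
    if value == "auto":
        # Use the number of ranks launched by torchrun; fall back to 1
        # when torch.distributed has not been initialized.
        return dist.get_world_size() if dist.is_initialized() else 1
    return int(value)
```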

## Quick Examples

```bash
pip3 install torch==2.9.1 transformers accelerate torchao bitsandbytes torchvision
pip3 install opencv-python-headless einops imageio-ffmpeg ftfy
pip3 install git+https://github.com/huggingface/diffusers.git # latest or >= 0.36.0
pip3 install git+https://github.com/vipshop/cache-dit.git # latest

git clone https://github.com/vipshop/cache-dit.git && cd cache-dit/examples

python3 generate.py flux --config config.yaml
torchrun --nproc_per_node=4 --local-ranks-filter=0 generate.py flux --config parallel_config.yaml
```
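
Besides `load_configs`, this PR also exports `load_cache_config` and `load_parallelism_config` (see `src/cache_dit/__init__.py` below). The sketch below shows how they might be used; only the `check_only=True` form appears in `examples/utils.py`, the remaining call signatures are assumptions:

```python
import cache_dit

# From examples/utils.py: ask whether the YAML declares a parallelism_config,
# so the caller can decide whether to initialize torch.distributed.
needs_dist = cache_dit.load_parallelism_config(
    "parallel_config.yaml", check_only=True
)

# Assumed usage: load only the cache section of a YAML file.
cache_config = cache_dit.load_cache_config("config.yaml")
```
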
12 changes: 12 additions & 0 deletions examples/config.yaml
@@ -0,0 +1,12 @@
cache_config:
  max_warmup_steps: 8
  warmup_interval: 2
  max_cached_steps: -1
  max_continuous_cached_steps: 2
  Fn_compute_blocks: 1
  Bn_compute_blocks: 0
  num_inference_steps: 28
  steps_computation_mask: fast
  residual_diff_threshold: 0.12
  enable_taylorseer: true
  taylorseer_order: 1
15 changes: 15 additions & 0 deletions examples/parallel_config.yaml
@@ -0,0 +1,15 @@
cache_config:
  max_warmup_steps: 8
  warmup_interval: 2
  max_cached_steps: -1
  max_continuous_cached_steps: 2
  Fn_compute_blocks: 1
  Bn_compute_blocks: 0
  residual_diff_threshold: 0.12
  enable_taylorseer: true
  taylorseer_order: 1
parallelism_config:
  ulysses_size: auto
  parallel_kwargs:
    attention_backend: native
    extra_parallel_modules: ["text_encoder", "vae"]
183 changes: 111 additions & 72 deletions examples/utils.py
@@ -108,6 +108,14 @@ def get_args(
default=None,
help="Override mask image path if provided",
)
# Acceleration Config path
parser.add_argument(
"--config-path",
"--config",
type=str,
default=None,
help="Path to CacheDiT configuration YAML file",
)
# Sampling settings
parser.add_argument(
"--prompt",
@@ -1210,81 +1218,91 @@ def _set_backend(module):
f"Original error: {e}"
) from e

if args.cache or args.parallel_type is not None:

cache_config = kwargs.pop("cache_config", None)
parallelism_config = kwargs.pop("parallelism_config", None)
if args.cache or args.parallel_type is not None or args.config_path is not None:

backend = (
ParallelismBackend.NATIVE_PYTORCH
if args.parallel_type in ["tp"]
else ParallelismBackend.NATIVE_DIFFUSER
)
if args.config_path is None:
# Construct acceleration configs from command line args if config path is not provided
cache_config = kwargs.pop("cache_config", None)
parallelism_config = kwargs.pop("parallelism_config", None)

extra_parallel_modules = prepare_extra_parallel_modules(
args,
pipe_or_adapter,
custom_extra_modules=kwargs.get("extra_parallel_modules", None),
)
backend = (
ParallelismBackend.NATIVE_PYTORCH
if args.parallel_type in ["tp"]
else ParallelismBackend.NATIVE_DIFFUSER
)

parallel_kwargs = {
"attention_backend": ("native" if not args.attn else args.attn),
# e.g., text_encoder_2 in FluxPipeline, text_encoder in Flux2Pipeline
"extra_parallel_modules": extra_parallel_modules,
}
if backend == ParallelismBackend.NATIVE_PYTORCH:
if args.attn is None:
parallel_kwargs["attention_backend"] = None

if backend == ParallelismBackend.NATIVE_DIFFUSER:
parallel_kwargs.update(
{
"experimental_ulysses_anything": args.ulysses_anything,
"experimental_ulysses_float8": args.ulysses_float8,
"experimental_ulysses_async": args.ulysses_async,
}
extra_parallel_modules = prepare_extra_parallel_modules(
args,
pipe_or_adapter,
custom_extra_modules=kwargs.get("extra_parallel_modules", None),
)

# Caching and Parallelism
cache_dit.enable_cache(
pipe_or_adapter,
cache_config=(
DBCacheConfig(
Fn_compute_blocks=args.Fn_compute_blocks,
Bn_compute_blocks=args.Bn_compute_blocks,
max_warmup_steps=args.max_warmup_steps,
warmup_interval=args.warmup_interval,
max_cached_steps=args.max_cached_steps,
max_continuous_cached_steps=args.max_continuous_cached_steps,
residual_diff_threshold=args.residual_diff_threshold,
enable_separate_cfg=kwargs.get("enable_separate_cfg", None),
steps_computation_mask=kwargs.get("steps_computation_mask", None),
)
if cache_config is None and args.cache
else cache_config
),
calibrator_config=(
TaylorSeerCalibratorConfig(
taylorseer_order=args.taylorseer_order,
parallel_kwargs = {
"attention_backend": ("native" if not args.attn else args.attn),
# e.g., text_encoder_2 in FluxPipeline, text_encoder in Flux2Pipeline
"extra_parallel_modules": extra_parallel_modules,
}
if backend == ParallelismBackend.NATIVE_PYTORCH:
if args.attn is None:
parallel_kwargs["attention_backend"] = None

if backend == ParallelismBackend.NATIVE_DIFFUSER:
parallel_kwargs.update(
{
"experimental_ulysses_anything": args.ulysses_anything,
"experimental_ulysses_float8": args.ulysses_float8,
"experimental_ulysses_async": args.ulysses_async,
}
)
if args.taylorseer
else None
),
params_modifiers=kwargs.get("params_modifiers", None),
parallelism_config=(
ParallelismConfig(
ulysses_size=(
dist.get_world_size() if args.parallel_type == "ulysses" else None
),
ring_size=(dist.get_world_size() if args.parallel_type == "ring" else None),
tp_size=(dist.get_world_size() if args.parallel_type == "tp" else None),
backend=backend,
parallel_kwargs=parallel_kwargs,
)
if parallelism_config is None and args.parallel_type in ["ulysses", "ring", "tp"]
else parallelism_config
),
)

# Caching and Parallelism
cache_dit.enable_cache(
pipe_or_adapter,
cache_config=(
DBCacheConfig(
Fn_compute_blocks=args.Fn_compute_blocks,
Bn_compute_blocks=args.Bn_compute_blocks,
max_warmup_steps=args.max_warmup_steps,
warmup_interval=args.warmup_interval,
max_cached_steps=args.max_cached_steps,
max_continuous_cached_steps=args.max_continuous_cached_steps,
residual_diff_threshold=args.residual_diff_threshold,
enable_separate_cfg=kwargs.get("enable_separate_cfg", None),
steps_computation_mask=kwargs.get("steps_computation_mask", None),
)
if cache_config is None and args.cache
else cache_config
),
calibrator_config=(
TaylorSeerCalibratorConfig(
taylorseer_order=args.taylorseer_order,
)
if args.taylorseer
else None
),
params_modifiers=kwargs.get("params_modifiers", None),
parallelism_config=(
ParallelismConfig(
ulysses_size=(
dist.get_world_size() if args.parallel_type == "ulysses" else None
),
ring_size=(dist.get_world_size() if args.parallel_type == "ring" else None),
tp_size=(dist.get_world_size() if args.parallel_type == "tp" else None),
backend=backend,
parallel_kwargs=parallel_kwargs,
)
if parallelism_config is None
and args.parallel_type in ["ulysses", "ring", "tp"]
else parallelism_config
),
)
else:
# Apply acceleration configs from config path
cache_dit.enable_cache(
pipe_or_adapter,
**cache_dit.load_configs(args.config_path),
)
logger.info(f"Applied acceleration from {args.config_path}.")

# Quantization
# WARN: Must apply quantization after tensor parallelism is applied.
@@ -1338,11 +1356,14 @@ def strify(args, pipe_or_stats):
if args.ulysses_async:
base_str += "_ulysses_async"
if args.parallel_text_encoder:
base_str += "_TEP" # Text Encoder Parallelism
if "_TEP" not in base_str:
base_str += "_TEP" # Text Encoder Parallelism
if args.parallel_vae:
base_str += "_VAEP" # VAE Parallelism
if "_VAEP" not in base_str:
base_str += "_VAEP" # VAE Parallelism
if args.parallel_controlnet:
base_str += "_CNP" # ControlNet Parallelism
if "_CNP" not in base_str:
base_str += "_CNP" # ControlNet Parallelism
if args.attn is not None:
base_str += f"_{args.attn.strip('_')}"
return base_str
@@ -1376,6 +1397,24 @@ def maybe_init_distributed(args=None):
rank, device = get_rank_device()
current_platform.set_device(device)
return rank, device
elif args.config_path is not None:
# check if distributed is needed from config file
has_parallelism_config = cache_dit.load_parallelism_config(
args.config_path,
check_only=True,
)
if has_parallelism_config:
if not dist.is_initialized():
dist.init_process_group(
backend=backend,
)
rank, device = get_rank_device()
current_platform.set_device(device)
return rank, device
else:
# no distributed needed
rank, device = get_rank_device()
return rank, device
else:
# no distributed needed
rank, device = get_rank_device()
1 change: 1 addition & 0 deletions mkdocs.yml
@@ -67,6 +67,7 @@ nav:
- Low-Bits Quantization: user_guide/QUANTIZATION.md
- Attention Backends: user_guide/ATTENTION.md
- Torch Compile: user_guide/COMPILE.md
- Config with YAML: user_guide/LOAD_CONFIGS.md
- Metrics Tools: user_guide/METRICS.md
- Profiler Usage: user_guide/PROFILER.md
- API Documentation: user_guide/API_DOCS.md
5 changes: 4 additions & 1 deletion src/cache_dit/__init__.py
@@ -7,7 +7,10 @@

from cache_dit.utils import disable_print
from cache_dit.logger import init_logger
from cache_dit.caching import load_options
from cache_dit.caching import load_options # deprecated
from cache_dit.caching import load_cache_config
from cache_dit.caching import load_parallelism_config
from cache_dit.caching import load_configs
from cache_dit.caching import enable_cache
from cache_dit.caching import refresh_context
from cache_dit.caching import steps_mask
5 changes: 4 additions & 1 deletion src/cache_dit/caching/__init__.py
@@ -35,4 +35,7 @@
from cache_dit.caching.cache_interface import get_adapter
from cache_dit.caching.cache_interface import steps_mask

from cache_dit.caching.utils import load_options
from cache_dit.caching.utils import load_options # deprecated
from cache_dit.caching.utils import load_cache_config
from cache_dit.caching.utils import load_parallelism_config
from cache_dit.caching.utils import load_configs