diff --git a/docs/user_guide/LOAD_CONFIGS.md b/docs/user_guide/LOAD_CONFIGS.md new file mode 100644 index 000000000..9098d05da --- /dev/null +++ b/docs/user_guide/LOAD_CONFIGS.md @@ -0,0 +1,67 @@ +# Use YAML Config File + +Cache-DiT now supports loading acceleration configs from a custom YAML file. Here are some examples. + +## Single GPU inference + +Define a `config.yaml` file that contains: + +```yaml +cache_config: + max_warmup_steps: 8 + warmup_interval: 2 + max_cached_steps: -1 + max_continuous_cached_steps: 2 + Fn_compute_blocks: 1 + Bn_compute_blocks: 0 + residual_diff_threshold: 0.12 + enable_taylorseer: true + taylorseer_order: 1 +``` +Then, apply the acceleration config from the YAML file. + +```python +>>> import cache_dit +>>> cache_dit.enable_cache(pipe, **cache_dit.load_configs("config.yaml")) +``` + +## Distributed inference + +Define a `parallel_config.yaml` file that contains: + +```yaml +cache_config: + max_warmup_steps: 8 + warmup_interval: 2 + max_cached_steps: -1 + max_continuous_cached_steps: 2 + Fn_compute_blocks: 1 + Bn_compute_blocks: 0 + residual_diff_threshold: 0.12 + enable_taylorseer: true + taylorseer_order: 1 +parallelism_config: + ulysses_size: auto + parallel_kwargs: + attention_backend: native + extra_parallel_modules: ["text_encoder", "vae"] +``` +Then, apply the distributed inference acceleration config from the YAML file. `ulysses_size: auto` means that cache-dit will automatically detect the `world_size` and use it as the `ulysses_size`. Otherwise, you should manually set it to a specific integer, e.g., 4. +```python +>>> import cache_dit +>>> cache_dit.enable_cache(pipe, **cache_dit.load_configs("parallel_config.yaml")) +``` + +## Quick Examples + +```bash +pip3 install torch==2.9.1 transformers accelerate torchao bitsandbytes torchvision +pip3 install opencv-python-headless einops imageio-ffmpeg ftfy +pip3 install git+https://github.com/huggingface/diffusers.git # latest or >= 0.36.0 +pip3 install git+https://github.com/vipshop/cache-dit.git # latest + +git clone https://github.com/vipshop/cache-dit.git && cd cache-dit/examples + +python3 generate.py flux --config config.yaml +torchrun --nproc_per_node=4 --local-ranks-filter=0 generate.py flux --config parallel_config.yaml +``` diff --git a/examples/config.yaml b/examples/config.yaml new file mode 100644 index 000000000..82fe12234 --- /dev/null +++ b/examples/config.yaml @@ -0,0 +1,12 @@ +cache_config: + max_warmup_steps: 8 + warmup_interval: 2 + max_cached_steps: -1 + max_continuous_cached_steps: 2 + Fn_compute_blocks: 1 + Bn_compute_blocks: 0 + num_inference_steps: 28 + steps_computation_mask: fast + residual_diff_threshold: 0.12 + enable_taylorseer: true + taylorseer_order: 1 diff --git a/examples/parallel_config.yaml b/examples/parallel_config.yaml new file mode 100644 index 000000000..cbf482538 --- /dev/null +++ b/examples/parallel_config.yaml @@ -0,0 +1,15 @@ +cache_config: + max_warmup_steps: 8 + warmup_interval: 2 + max_cached_steps: -1 + max_continuous_cached_steps: 2 + Fn_compute_blocks: 1 + Bn_compute_blocks: 0 + residual_diff_threshold: 0.12 + enable_taylorseer: true + taylorseer_order: 1 +parallelism_config: + ulysses_size: auto + parallel_kwargs: + attention_backend: native + extra_parallel_modules: ["text_encoder", "vae"] diff --git a/examples/utils.py b/examples/utils.py index ce382153c..b0765759c 100644 --- a/examples/utils.py +++ b/examples/utils.py @@ -108,6 +108,14 @@ def get_args( default=None, help="Override mask image path if provided", ) + # Acceleration Config path + parser.add_argument( + "--config-path", + "--config", 
+ type=str, + default=None, + help="Path to CacheDiT configuration YAML file", + ) # Sampling settings parser.add_argument( "--prompt", @@ -1210,81 +1218,91 @@ def _set_backend(module): f"Original error: {e}" ) from e - if args.cache or args.parallel_type is not None: - - cache_config = kwargs.pop("cache_config", None) - parallelism_config = kwargs.pop("parallelism_config", None) + if args.cache or args.parallel_type is not None or args.config_path is not None: - backend = ( - ParallelismBackend.NATIVE_PYTORCH - if args.parallel_type in ["tp"] - else ParallelismBackend.NATIVE_DIFFUSER - ) + if args.config_path is None: + # Construct acceleration configs from command line args if config path is not provided + cache_config = kwargs.pop("cache_config", None) + parallelism_config = kwargs.pop("parallelism_config", None) - extra_parallel_modules = prepare_extra_parallel_modules( - args, - pipe_or_adapter, - custom_extra_modules=kwargs.get("extra_parallel_modules", None), - ) + backend = ( + ParallelismBackend.NATIVE_PYTORCH + if args.parallel_type in ["tp"] + else ParallelismBackend.NATIVE_DIFFUSER + ) - parallel_kwargs = { - "attention_backend": ("native" if not args.attn else args.attn), - # e.g., text_encoder_2 in FluxPipeline, text_encoder in Flux2Pipeline - "extra_parallel_modules": extra_parallel_modules, - } - if backend == ParallelismBackend.NATIVE_PYTORCH: - if args.attn is None: - parallel_kwargs["attention_backend"] = None - - if backend == ParallelismBackend.NATIVE_DIFFUSER: - parallel_kwargs.update( - { - "experimental_ulysses_anything": args.ulysses_anything, - "experimental_ulysses_float8": args.ulysses_float8, - "experimental_ulysses_async": args.ulysses_async, - } + extra_parallel_modules = prepare_extra_parallel_modules( + args, + pipe_or_adapter, + custom_extra_modules=kwargs.get("extra_parallel_modules", None), ) - # Caching and Parallelism - cache_dit.enable_cache( - pipe_or_adapter, - cache_config=( - DBCacheConfig( - Fn_compute_blocks=args.Fn_compute_blocks, - Bn_compute_blocks=args.Bn_compute_blocks, - max_warmup_steps=args.max_warmup_steps, - warmup_interval=args.warmup_interval, - max_cached_steps=args.max_cached_steps, - max_continuous_cached_steps=args.max_continuous_cached_steps, - residual_diff_threshold=args.residual_diff_threshold, - enable_separate_cfg=kwargs.get("enable_separate_cfg", None), - steps_computation_mask=kwargs.get("steps_computation_mask", None), - ) - if cache_config is None and args.cache - else cache_config - ), - calibrator_config=( - TaylorSeerCalibratorConfig( - taylorseer_order=args.taylorseer_order, + parallel_kwargs = { + "attention_backend": ("native" if not args.attn else args.attn), + # e.g., text_encoder_2 in FluxPipeline, text_encoder in Flux2Pipeline + "extra_parallel_modules": extra_parallel_modules, + } + if backend == ParallelismBackend.NATIVE_PYTORCH: + if args.attn is None: + parallel_kwargs["attention_backend"] = None + + if backend == ParallelismBackend.NATIVE_DIFFUSER: + parallel_kwargs.update( + { + "experimental_ulysses_anything": args.ulysses_anything, + "experimental_ulysses_float8": args.ulysses_float8, + "experimental_ulysses_async": args.ulysses_async, + } ) - if args.taylorseer - else None - ), - params_modifiers=kwargs.get("params_modifiers", None), - parallelism_config=( - ParallelismConfig( - ulysses_size=( - dist.get_world_size() if args.parallel_type == "ulysses" else None - ), - ring_size=(dist.get_world_size() if args.parallel_type == "ring" else None), - tp_size=(dist.get_world_size() if args.parallel_type == 
"tp" else None), - backend=backend, - parallel_kwargs=parallel_kwargs, - ) - if parallelism_config is None and args.parallel_type in ["ulysses", "ring", "tp"] - else parallelism_config - ), - ) + + # Caching and Parallelism + cache_dit.enable_cache( + pipe_or_adapter, + cache_config=( + DBCacheConfig( + Fn_compute_blocks=args.Fn_compute_blocks, + Bn_compute_blocks=args.Bn_compute_blocks, + max_warmup_steps=args.max_warmup_steps, + warmup_interval=args.warmup_interval, + max_cached_steps=args.max_cached_steps, + max_continuous_cached_steps=args.max_continuous_cached_steps, + residual_diff_threshold=args.residual_diff_threshold, + enable_separate_cfg=kwargs.get("enable_separate_cfg", None), + steps_computation_mask=kwargs.get("steps_computation_mask", None), + ) + if cache_config is None and args.cache + else cache_config + ), + calibrator_config=( + TaylorSeerCalibratorConfig( + taylorseer_order=args.taylorseer_order, + ) + if args.taylorseer + else None + ), + params_modifiers=kwargs.get("params_modifiers", None), + parallelism_config=( + ParallelismConfig( + ulysses_size=( + dist.get_world_size() if args.parallel_type == "ulysses" else None + ), + ring_size=(dist.get_world_size() if args.parallel_type == "ring" else None), + tp_size=(dist.get_world_size() if args.parallel_type == "tp" else None), + backend=backend, + parallel_kwargs=parallel_kwargs, + ) + if parallelism_config is None + and args.parallel_type in ["ulysses", "ring", "tp"] + else parallelism_config + ), + ) + else: + # Apply acceleration configs from config path + cache_dit.enable_cache( + pipe_or_adapter, + **cache_dit.load_configs(args.config_path), + ) + logger.info(f"Applied acceleration from {args.config_path}.") # Quantization # WARN: Must apply quantization after tensor parallelism is applied. 
@@ -1338,11 +1356,14 @@ def strify(args, pipe_or_stats): if args.ulysses_async: base_str += "_ulysses_async" if args.parallel_text_encoder: - base_str += "_TEP" # Text Encoder Parallelism + if "_TEP" not in base_str: + base_str += "_TEP" # Text Encoder Parallelism if args.parallel_vae: - base_str += "_VAEP" # VAE Parallelism + if "_VAEP" not in base_str: + base_str += "_VAEP" # VAE Parallelism if args.parallel_controlnet: - base_str += "_CNP" # ControlNet Parallelism + if "_CNP" not in base_str: + base_str += "_CNP" # ControlNet Parallelism if args.attn is not None: base_str += f"_{args.attn.strip('_')}" return base_str @@ -1376,6 +1397,24 @@ def maybe_init_distributed(args=None): rank, device = get_rank_device() current_platform.set_device(device) return rank, device + elif args.config_path is not None: + # check if distributed is needed from config file + has_parallelism_config = cache_dit.load_parallelism_config( + args.config_path, + check_only=True, + ) + if has_parallelism_config: + if not dist.is_initialized(): + dist.init_process_group( + backend=backend, + ) + rank, device = get_rank_device() + current_platform.set_device(device) + return rank, device + else: + # no distributed needed + rank, device = get_rank_device() + return rank, device else: # no distributed needed rank, device = get_rank_device() diff --git a/mkdocs.yml b/mkdocs.yml index e440667be..049876206 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -67,6 +67,7 @@ nav: - Low-Bits Quantization: user_guide/QUANTIZATION.md - Attention Backends: user_guide/ATTENTION.md - Torch Compile: user_guide/COMPILE.md + - Config with YAML: user_guide/LOAD_CONFIGS.md - Metrics Tools: user_guide/METRICS.md - Profiler Usage: user_guide/PROFILER.md - API Docmentation: user_guide/API_DOCS.md diff --git a/src/cache_dit/__init__.py b/src/cache_dit/__init__.py index 70b1f7ba0..4f26b5c74 100644 --- a/src/cache_dit/__init__.py +++ b/src/cache_dit/__init__.py @@ -7,7 +7,10 @@ from cache_dit.utils import disable_print from cache_dit.logger import init_logger -from cache_dit.caching import load_options +from cache_dit.caching import load_options # deprecated +from cache_dit.caching import load_cache_config +from cache_dit.caching import load_parallelism_config +from cache_dit.caching import load_configs from cache_dit.caching import enable_cache from cache_dit.caching import refresh_context from cache_dit.caching import steps_mask diff --git a/src/cache_dit/caching/__init__.py b/src/cache_dit/caching/__init__.py index 9ece75301..d300fff9e 100644 --- a/src/cache_dit/caching/__init__.py +++ b/src/cache_dit/caching/__init__.py @@ -35,4 +35,7 @@ from cache_dit.caching.cache_interface import get_adapter from cache_dit.caching.cache_interface import steps_mask -from cache_dit.caching.utils import load_options +from cache_dit.caching.utils import load_options # deprecated +from cache_dit.caching.utils import load_cache_config +from cache_dit.caching.utils import load_parallelism_config +from cache_dit.caching.utils import load_configs diff --git a/src/cache_dit/caching/cache_interface.py b/src/cache_dit/caching/cache_interface.py index df83e24b6..2c8bc6abf 100644 --- a/src/cache_dit/caching/cache_interface.py +++ b/src/cache_dit/caching/cache_interface.py @@ -307,6 +307,22 @@ def enable_cache( parallelism_config.parallel_kwargs["has_controlnet"] = _has_controlnet( pipe_or_adapter, ) + parallelism_config._has_controlnet = parallelism_config.parallel_kwargs[ + "has_controlnet" + ] + + # Parse extra parallel modules from names to actual modules + if ( + 
extra_parallel_module := parallelism_config.parallel_kwargs.get( + "extra_parallel_modules", None + ) + ) is not None: + parallelism_config.parallel_kwargs["extra_parallel_modules"] = ( + _parse_extra_parallel_modules( + pipe_or_adapter, + extra_parallel_module, + ) + ) transformers = [] if isinstance(pipe_or_adapter, DiffusionPipeline): @@ -357,6 +373,62 @@ def _has_controlnet(pipe_or_adapter: DiffusionPipeline | BlockAdapter) -> bool: return False +def _parse_text_encoder( + pipe: DiffusionPipeline, +) -> Tuple[Optional[torch.nn.Module], Optional[str]]: + pipe_cls_name = pipe.__class__.__name__ + if ( + hasattr(pipe, "text_encoder_2") + and not pipe_cls_name.startswith("Hunyuan") + and not pipe_cls_name.startswith("Kandinsky") + ): + # Specific for FluxPipeline, FLUX.1-dev + return getattr(pipe, "text_encoder_2"), "text_encoder_2" + elif hasattr(pipe, "text_encoder_3"): # HiDream pipeline + return getattr(pipe, "text_encoder_3"), "text_encoder_3" + elif hasattr(pipe, "text_encoder"): # General case + return getattr(pipe, "text_encoder"), "text_encoder" + else: + return None, None + + +def _parse_extra_parallel_modules( + pipe_or_adapter: DiffusionPipeline | BlockAdapter, + extra_parallel_module: List[str | torch.nn.Module], +) -> Union[List[torch.nn.Module], List]: + if isinstance(pipe_or_adapter, BlockAdapter): + pipe = pipe_or_adapter.pipe + else: + pipe = pipe_or_adapter + + if not extra_parallel_module: # empty list + return [] + + parsed_extra_parallel_modules: List[torch.nn.Module] = [] + for module_or_name in extra_parallel_module: + if isinstance(module_or_name, torch.nn.Module): + parsed_extra_parallel_modules.append(module_or_name) + continue + + if hasattr(pipe, module_or_name): + if module_or_name == "text_encoder": + # Special handling for text encoder + text_encoder, _ = _parse_text_encoder(pipe) + if text_encoder is not None: + parsed_extra_parallel_modules.append(text_encoder) + else: + logger.warning( + "Text encoder not found in the pipeline for extra parallel module." + ) + else: + parsed_extra_parallel_modules.append(getattr(pipe, module_or_name)) + else: + logger.warning( + f"Extra parallel module name {module_or_name} not found in the pipeline." 
+ ) + return parsed_extra_parallel_modules + + + def refresh_context( transformer: torch.nn.Module, **force_refresh_kwargs, diff --git a/src/cache_dit/caching/utils.py b/src/cache_dit/caching/utils.py index 11966a912..469f9ae00 100644 --- a/src/cache_dit/caching/utils.py +++ b/src/cache_dit/caching/utils.py @@ -1,5 +1,16 @@ import yaml import copy +from typing import Tuple, Optional, Union +from cache_dit.caching.cache_contexts import ( + DBCacheConfig, + TaylorSeerCalibratorConfig, + DBPruneConfig, + CalibratorConfig, +) +from cache_dit.parallelism import ParallelismConfig, ParallelismBackend +from cache_dit.logger import init_logger + +logger = init_logger(__name__) def load_cache_options_from_dict(cache_kwargs: dict, reset: bool = False) -> dict: @@ -8,10 +19,6 @@ def load_cache_options_from_dict(cache_kwargs: dict, reset: bool = False) -> dic kwargs: dict = copy.deepcopy(cache_kwargs) cache_context_kwargs = {} if kwargs.get("enable_taylorseer", False): - from cache_dit.caching.cache_contexts.calibrators import ( - TaylorSeerCalibratorConfig, - ) - cache_context_kwargs["calibrator_config"] = ( TaylorSeerCalibratorConfig( enable_calibrator=kwargs.get("enable_taylorseer"), @@ -29,23 +36,20 @@ def load_cache_options_from_dict(cache_kwargs: dict, reset: bool = False) -> dic ) if "cache_type" not in kwargs: - from cache_dit.caching.cache_contexts import BasicCacheConfig - + # Assume DBCache if cache_type is not specified cache_context_kwargs["cache_config"] = ( - BasicCacheConfig() if not reset else BasicCacheConfig().reset() + DBCacheConfig() if not reset else DBCacheConfig().reset() ) cache_context_kwargs["cache_config"].update(**kwargs) else: cache_type = str(kwargs.get("cache_type", None)) if cache_type == "DBCache": - from cache_dit.caching.cache_contexts import DBCacheConfig cache_context_kwargs["cache_config"] = ( DBCacheConfig() if not reset else DBCacheConfig().reset() ) cache_context_kwargs["cache_config"].update(**kwargs) elif cache_type == "DBPrune": - from cache_dit.caching.cache_contexts import DBPruneConfig cache_context_kwargs["cache_config"] = ( DBPruneConfig() if not reset else DBPruneConfig().reset() @@ -54,12 +58,6 @@ def load_cache_options_from_dict(cache_kwargs: dict, reset: bool = False) -> dic else: raise ValueError(f"Unsupported cache_type: {cache_type}.") - if "parallelism_config" in kwargs: - from cache_dit.parallelism import ParallelismConfig - - parallelism_kwargs = kwargs.get("parallelism_config", {}) - cache_context_kwargs["parallelism_config"] = ParallelismConfig(**parallelism_kwargs) - return cache_context_kwargs except Exception as e: @@ -91,9 +89,199 @@ def load_options(path_or_dict: str | dict, reset: bool = False) -> dict: Returns: `dict`: A dictionary containing the loaded cache options. """ + # Deprecated function warning + logger.warning( + "`load_options` is deprecated and will be removed in future versions. " + "Please use `load_configs` instead." + ) if isinstance(path_or_dict, str): return load_cache_options_from_yaml(path_or_dict, reset) elif isinstance(path_or_dict, dict): return load_cache_options_from_dict(path_or_dict, reset) else: raise ValueError("Input must be a file path (str) or a configuration dictionary (dict).") + + +def load_cache_config( + path_or_dict: str | dict, **kwargs +) -> Tuple[DBCacheConfig, Optional[CalibratorConfig]]: + r""" + New API that loads only the cache configuration from a YAML file or a dictionary. Assumes + that the YAML contains a 'cache_config' section and returns only that section. + Raises ValueError if it is not found. 
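+ Example (a minimal sketch; assumes "config.yaml" defines a 'cache_config' section like the one in docs/user_guide/LOAD_CONFIGS.md): + + >>> cache_config, calibrator_config = load_cache_config("config.yaml") + 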
Args: + path_or_dict (`str` or `dict`): + The file path to the YAML configuration file or a dictionary containing the configuration. + reset (`bool`, *optional*, defaults to `False`, passed via `**kwargs`): + Whether to reset the configuration to default values before applying the loaded settings. + This is useful when you want to ensure that only the settings specified in the file or dictionary + are applied, without retaining any previous configurations (e.g., when using ParamsModifier to modify + existing configurations). + Returns: + `Tuple[DBCacheConfig, Optional[CalibratorConfig]]`: The loaded cache configuration and, if + `enable_taylorseer` is set, the calibrator configuration. + """ + if isinstance(path_or_dict, str): + try: + with open(path_or_dict, "r") as f: + # Keep the loaded config separate from **kwargs so caller options (e.g., reset) are honored + config_kwargs: dict = yaml.safe_load(f) + except FileNotFoundError: + raise FileNotFoundError(f"Configuration file not found: {path_or_dict}") + except yaml.YAMLError as e: + raise yaml.YAMLError(f"YAML file parsing error: {str(e)}") + elif isinstance(path_or_dict, dict): + config_kwargs: dict = copy.deepcopy(path_or_dict) + else: + raise ValueError("Input must be a file path (str) or a configuration dictionary (dict).") + + if "cache_config" not in config_kwargs: + raise ValueError("No 'cache_config' section found in the provided configuration.") + + cache_config_kwargs = config_kwargs["cache_config"] + # Parse steps_mask if it exists + if "steps_computation_mask" in cache_config_kwargs: + steps_computation_mask = cache_config_kwargs["steps_computation_mask"] + if isinstance(steps_computation_mask, str): + assert ( + "num_inference_steps" in cache_config_kwargs + ), "To parse steps_mask from str, 'num_inference_steps' must be provided in cache_config." + from .cache_interface import steps_mask + + num_inference_steps = cache_config_kwargs["num_inference_steps"] + cache_config_kwargs["steps_computation_mask"] = steps_mask( + total_steps=num_inference_steps, mask_policy=steps_computation_mask + ) + # Reuse load_cache_options_from_dict to parse cache_config + cache_context_kwargs = load_cache_options_from_dict( + cache_config_kwargs, kwargs.get("reset", False) + ) + cache_config: DBCacheConfig = cache_context_kwargs.get("cache_config", None) + calibrator_config = cache_context_kwargs.get("calibrator_config", None) + if cache_config is None: + raise ValueError("Failed to load 'cache_config'.") + return cache_config, calibrator_config + + +def load_parallelism_config( + path_or_dict: str | dict, **kwargs +) -> Optional[ParallelismConfig] | bool: + r""" + Load parallelism configuration from a YAML file or a dictionary. Assumes that the YAML + may contain a 'parallelism_config' section, and returns only that section. Returns None + if the section is missing, or a bool indicating its presence when `check_only=True` is passed. + Args: + path_or_dict (`str` or `dict`): + The file path to the YAML configuration file or a dictionary containing the configuration. + Returns: + `ParallelismConfig` (*optional*) or `bool`: An instance of ParallelismConfig containing the loaded parallelism configuration, None if the section is absent, or a bool when `check_only=True` is passed. 
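+ Example (a minimal sketch; returns None when the YAML has no 'parallelism_config' section, and a bool when called with check_only=True): + + >>> parallelism_config = load_parallelism_config("parallel_config.yaml") + >>> has_parallelism = load_parallelism_config("parallel_config.yaml", check_only=True) 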
+ """ + if isinstance(path_or_dict, str): + try: + with open(path_or_dict, "r") as f: + kwargs: dict = yaml.safe_load(f) + except FileNotFoundError: + raise FileNotFoundError(f"Configuration file not found: {path_or_dict}") + except yaml.YAMLError as e: + raise yaml.YAMLError(f"YAML file parsing error: {str(e)}") + elif isinstance(path_or_dict, dict): + kwargs: dict = copy.deepcopy(path_or_dict) + else: + raise ValueError("Input must be a file path (str) or a configuration dictionary (dict).") + + if kwargs.get("check_only", False): + return "parallelism_config" in kwargs + + # Allow missing parallelism_config + if "parallelism_config" not in kwargs: + return None + + parallelism_config_kwargs = kwargs["parallelism_config"] + if "backend" in parallelism_config_kwargs: + backend_str = parallelism_config_kwargs["backend"] + parallelism_config_kwargs["backend"] = ParallelismBackend.from_str(backend_str) + + def _maybe_auto_parallel_size(size: str | int | None) -> Optional[int]: + if size is None: + return None + if isinstance(size, int): + return size + if isinstance(size, str) and size.lower() == "auto": + import torch.distributed as dist + + size = 1 + if dist.is_initialized(): + # Assume world size is the parallel size + size = dist.get_world_size() + if size == 1: + logger.warning( + "Auto parallel size selected as 1. Make sure to run with torch.distributed " + "to utilize multiple devices for parallelism." + ) + else: + logger.info(f"Auto selected parallel size to {size}.") + return size + raise ValueError(f"Invalid parallel size value: {size}. Must be int or 'auto'.") + + if kwargs.get("auto_parallel_size", True): + if "ulysses_size" in parallelism_config_kwargs: + parallelism_config_kwargs["ulysses_size"] = _maybe_auto_parallel_size( + parallelism_config_kwargs["ulysses_size"] + ) + if "ring_size" in parallelism_config_kwargs: + parallelism_config_kwargs["ring_size"] = _maybe_auto_parallel_size( + parallelism_config_kwargs["ring_size"] + ) + if "tp_size" in parallelism_config_kwargs: + parallelism_config_kwargs["tp_size"] = _maybe_auto_parallel_size( + parallelism_config_kwargs["tp_size"] + ) + + parallelism_config = ParallelismConfig(**parallelism_config_kwargs) + return parallelism_config + + +def load_configs( + path_or_dict: str | dict, + return_dict: bool = True, + **kwargs, +) -> Union[Tuple[DBCacheConfig, Optional[CalibratorConfig], ParallelismConfig], dict]: + r""" + Load both cache and parallelism configurations from a YAML file or a dictionary. For example, + the YAML file can be structured as follows: + ```yaml + cache_config: + max_warmup_steps: 8 + warmup_interval: 2 + max_cached_steps: -1 + max_continuous_cached_steps: 2 + Fn_compute_blocks: 1 + Bn_compute_blocks: 0 + residual_diff_threshold: 0.12 + enable_taylorseer: true + taylorseer_order: 1 + parallelism_config: + ulysses_size: 4 + parallel_kwargs: + attention_backend: native + experimental_ulysses_anything: true + experimental_ulysses_float8: true + extra_parallel_modules: ["text_encoder", "vae"] + ``` + Args: + path_or_dict (`str` or `dict`): + The file path to the YAML configuration file or a dictionary containing the configuration. + Returns: + `Tuple[DBCacheConfig, Optional[CalibratorConfig], ParallelismConfig]`: A tuple containing the loaded + cache configuration, optional calibrator configuration, and parallelism configuration. If `return_dict` + is set to `True`, returns a dictionary with keys "cache_config", "calibrator_config", and "parallelism_config". 
+ """ + cache_config, calibrator_config = load_cache_config(path_or_dict, **kwargs) + parallelism_config = load_parallelism_config(path_or_dict, **kwargs) + if isinstance(parallelism_config, bool): + parallelism_config = None + if return_dict: + return { + "cache_config": cache_config, + "calibrator_config": calibrator_config, + "parallelism_config": parallelism_config, + } + return cache_config, calibrator_config, parallelism_config diff --git a/src/cache_dit/parallelism/backend.py b/src/cache_dit/parallelism/backend.py index 6091dccb2..b27179f9f 100644 --- a/src/cache_dit/parallelism/backend.py +++ b/src/cache_dit/parallelism/backend.py @@ -2,13 +2,16 @@ class ParallelismBackend(Enum): + AUTO = "Auto" NATIVE_DIFFUSER = "Native_Diffuser" NATIVE_PYTORCH = "Native_PyTorch" NONE = "None" @classmethod def is_supported(cls, backend: "ParallelismBackend") -> bool: - if backend == cls.NATIVE_PYTORCH: + if backend == cls.AUTO: + return True + elif backend == cls.NATIVE_PYTORCH: return True elif backend == cls.NATIVE_DIFFUSER: try: @@ -25,5 +28,12 @@ def is_supported(cls, backend: "ParallelismBackend") -> bool: return True return False + @classmethod + def from_str(cls, backend_str: str) -> "ParallelismBackend": + for backend in cls: + if backend.value.lower() == backend_str.lower(): + return backend + raise ValueError(f"Unsupported parallelism backend: {backend_str}.") + def __str__(self) -> str: return self.value diff --git a/src/cache_dit/parallelism/config.py b/src/cache_dit/parallelism/config.py index 636a1fcf8..e1533ee41 100644 --- a/src/cache_dit/parallelism/config.py +++ b/src/cache_dit/parallelism/config.py @@ -8,8 +8,9 @@ @dataclasses.dataclass class ParallelismConfig: - # Parallelism backend, defaults to NATIVE_DIFFUSER - backend: ParallelismBackend = ParallelismBackend.NATIVE_DIFFUSER + # Parallelism backend, defaults to AUTO. We will auto select the backend + # based on the parallelism configuration. + backend: ParallelismBackend = ParallelismBackend.AUTO # Context parallelism config # ulysses_size (`int`, *optional*): # The degree of ulysses parallelism. @@ -26,12 +27,27 @@ class ParallelismConfig: # NATIVE_DIFFUSER backend, it can include `cp_plan` and # `attention_backend` arguments for `Context Parallelism`. parallel_kwargs: Optional[Dict[str, Any]] = dataclasses.field(default_factory=dict) + # Some internal fields for utils usage + _has_text_encoder: bool = False + _has_auto_encoder: bool = False + _has_controlnet: bool = False def __post_init__(self): assert ParallelismBackend.is_supported(self.backend), ( f"Parallel backend {self.backend} is not supported. " f"Please make sure the required packages are installed." 
) + if self.backend == ParallelismBackend.AUTO: + # Auto select the backend based on the parallelism configuration + if (self.ulysses_size is not None and self.ulysses_size > 1) or ( + self.ring_size is not None and self.ring_size > 1 + ): + self.backend = ParallelismBackend.NATIVE_DIFFUSER + elif self.tp_size is not None and self.tp_size > 1: + self.backend = ParallelismBackend.NATIVE_PYTORCH + else: + self.backend = ParallelismBackend.NONE + logger.info(f"Auto selected parallelism backend for transformer: {self.backend}") # Validate the parallelism configuration and auto adjust the backend if needed if self.tp_size is not None and self.tp_size > 1: @@ -67,6 +83,13 @@ def __post_init__(self): ) self.backend = ParallelismBackend.NATIVE_DIFFUSER + def enabled(self) -> bool: + return ( + (self.ulysses_size is not None and self.ulysses_size > 1) + or (self.ring_size is not None and self.ring_size > 1) + or (self.tp_size is not None and self.tp_size > 1) + ) + def strify( self, details: bool = False, @@ -108,11 +131,11 @@ def strify( parallel_str += f"Ring{self.ring_size}" if self.tp_size is not None: parallel_str += f"TP{self.tp_size}" - if text_encoder: + if text_encoder or self._has_text_encoder: parallel_str += "_TEP" # Text Encoder Parallelism - if vae: + if vae or self._has_auto_encoder: parallel_str += "_VAEP" # VAE Parallelism - if controlnet: + if controlnet or self._has_controlnet: parallel_str += "_CNP" # ControlNet Parallelism return parallel_str @@ -137,6 +160,7 @@ def text_encoder_world_size(self) -> int: assert ( world_size is None or world_size > 1 ), "Text encoder world size must be None or greater than 1 for parallelism." + self._has_text_encoder = True return world_size @property @@ -146,8 +170,19 @@ def auto_encoder_world_size(self) -> int: assert ( world_size is None or world_size > 1 ), "VAE world size must be None or greater than 1 for parallelism." + self._has_auto_encoder = True return world_size @property def vae_world_size(self) -> int: # alias of auto_encoder_world_size - return self.vae_world_size + return self.auto_encoder_world_size + + @property + def controlnet_world_size(self) -> int: + """Get the world size for ControlNet parallelism.""" + world_size = self._get_extra_module_world_size() + assert ( + world_size is None or world_size > 1 + ), "ControlNet world size must be None or greater than 1 for parallelism." + self._has_controlnet = True + return world_size
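As a quick reference for the new `ParallelismBackend.AUTO` default, here is a minimal sketch of how `ParallelismConfig.__post_init__` resolves the backend (assumes the NATIVE_DIFFUSER and NATIVE_PYTORCH backends are installed; the result is still subject to the validation checks that follow the auto-selection):

```python
from cache_dit.parallelism import ParallelismConfig, ParallelismBackend

# ulysses_size > 1 (or ring_size > 1) -> context parallelism via NATIVE_DIFFUSER
cp_config = ParallelismConfig(ulysses_size=4)
assert cp_config.backend == ParallelismBackend.NATIVE_DIFFUSER
assert cp_config.enabled()

# tp_size > 1 -> tensor parallelism via NATIVE_PYTORCH
tp_config = ParallelismConfig(tp_size=4)
assert tp_config.backend == ParallelismBackend.NATIVE_PYTORCH

# No parallel sizes set -> backend resolves to NONE and parallelism stays disabled
no_parallel = ParallelismConfig()
assert not no_parallel.enabled()
```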